7081e3da6e1d68961e84ddc66e19607a31799522
[sfrench/samba-autobuild/.git] / python / samba / samdb.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Matthias Dieter Wallnoefer 2009
4 #
5 # Based on the original in EJS:
6 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
7 # Copyright (C) Giampaolo Lauria <lauria2@yahoo.com> 2011
8 #
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
21 #
22
23 """Convenience functions for using the SAM."""
24
25 import samba
26 import ldb
27 import time
28 import base64
29 import os
30 from samba import dsdb, dsdb_dns
31 from samba.ndr import ndr_unpack, ndr_pack
32 from samba.dcerpc import drsblobs, misc
33 from samba.common import normalise_int32
34
35 __docformat__ = "restructuredText"
36
37
38 class SamDB(samba.Ldb):
39     """The SAM database."""
40
41     hash_oid_name = {}
42     hash_well_known = {}
43
44     def __init__(self, url=None, lp=None, modules_dir=None, session_info=None,
45                  credentials=None, flags=0, options=None, global_schema=True,
46                  auto_connect=True, am_rodc=None):
47         self.lp = lp
48         if not auto_connect:
49             url = None
50         elif url is None and lp is not None:
51             url = lp.samdb_url()
52
53         self.url = url
54
55         super(SamDB, self).__init__(url=url, lp=lp, modules_dir=modules_dir,
56             session_info=session_info, credentials=credentials, flags=flags,
57             options=options)
58
59         if global_schema:
60             dsdb._dsdb_set_global_schema(self)
61
62         if am_rodc is not None:
63             dsdb._dsdb_set_am_rodc(self, am_rodc)
64
65     def connect(self, url=None, flags=0, options=None):
66         '''connect to the database'''
67         if self.lp is not None and not os.path.exists(url):
68             url = self.lp.private_path(url)
69         self.url = url
70
71         super(SamDB, self).connect(url=url, flags=flags,
72                 options=options)
73
74     def am_rodc(self):
75         '''return True if we are an RODC'''
76         return dsdb._am_rodc(self)
77
78     def am_pdc(self):
79         '''return True if we are an PDC emulator'''
80         return dsdb._am_pdc(self)
81
82     def domain_dn(self):
83         '''return the domain DN'''
84         return str(self.get_default_basedn())
85
86     def disable_account(self, search_filter):
87         """Disables an account
88
89         :param search_filter: LDAP filter to find the user (eg
90             samccountname=name)
91         """
92
93         flags = samba.dsdb.UF_ACCOUNTDISABLE
94         self.toggle_userAccountFlags(search_filter, flags, on=True)
95
96     def enable_account(self, search_filter):
97         """Enables an account
98
99         :param search_filter: LDAP filter to find the user (eg
100             samccountname=name)
101         """
102
103         flags = samba.dsdb.UF_ACCOUNTDISABLE | samba.dsdb.UF_PASSWD_NOTREQD
104         self.toggle_userAccountFlags(search_filter, flags, on=False)
105
106     def toggle_userAccountFlags(self, search_filter, flags, flags_str=None,
107                                 on=True, strict=False):
108         """Toggle_userAccountFlags
109
110         :param search_filter: LDAP filter to find the user (eg
111             samccountname=name)
112         :param flags: samba.dsdb.UF_* flags
113         :param on: on=True (default) => set, on=False => unset
114         :param strict: strict=False (default) ignore if no action is needed
115                  strict=True raises an Exception if...
116         """
117         res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
118                           expression=search_filter, attrs=["userAccountControl"])
119         if len(res) == 0:
120                 raise Exception("Unable to find account where '%s'" % search_filter)
121         assert(len(res) == 1)
122         account_dn = res[0].dn
123
124         old_uac = int(res[0]["userAccountControl"][0])
125         if on:
126             if strict and (old_uac & flags):
127                 error = "Account flag(s) '%s' already set" % flags_str
128                 raise Exception(error)
129
130             new_uac = old_uac | flags
131         else:
132             if strict and not (old_uac & flags):
133                 error = "Account flag(s) '%s' already unset" % flags_str
134                 raise Exception(error)
135
136             new_uac = old_uac & ~flags
137
138         if old_uac == new_uac:
139             return
140
141         mod = """
142 dn: %s
143 changetype: modify
144 delete: userAccountControl
145 userAccountControl: %u
146 add: userAccountControl
147 userAccountControl: %u
148 """ % (account_dn, old_uac, new_uac)
149         self.modify_ldif(mod)
150
151     def force_password_change_at_next_login(self, search_filter):
152         """Forces a password change at next login
153
154         :param search_filter: LDAP filter to find the user (eg
155             samccountname=name)
156         """
157         res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
158                           expression=search_filter, attrs=[])
159         if len(res) == 0:
160                 raise Exception('Unable to find user "%s"' % search_filter)
161         assert(len(res) == 1)
162         user_dn = res[0].dn
163
164         mod = """
165 dn: %s
166 changetype: modify
167 replace: pwdLastSet
168 pwdLastSet: 0
169 """ % (user_dn)
170         self.modify_ldif(mod)
171
172     def newgroup(self, groupname, groupou=None, grouptype=None,
173                  description=None, mailaddress=None, notes=None, sd=None,
174                  gidnumber=None, nisdomain=None):
175         """Adds a new group with additional parameters
176
177         :param groupname: Name of the new group
178         :param grouptype: Type of the new group
179         :param description: Description of the new group
180         :param mailaddress: Email address of the new group
181         :param notes: Notes of the new group
182         :param gidnumber: GID Number of the new group
183         :param nisdomain: NIS Domain Name of the new group
184         :param sd: security descriptor of the object
185         """
186
187         group_dn = "CN=%s,%s,%s" % (groupname, (groupou or "CN=Users"), self.domain_dn())
188
189         # The new user record. Note the reliance on the SAMLDB module which
190         # fills in the default informations
191         ldbmessage = {"dn": group_dn,
192             "sAMAccountName": groupname,
193             "objectClass": "group"}
194
195         if grouptype is not None:
196             ldbmessage["groupType"] = normalise_int32(grouptype)
197
198         if description is not None:
199             ldbmessage["description"] = description
200
201         if mailaddress is not None:
202             ldbmessage["mail"] = mailaddress
203
204         if notes is not None:
205             ldbmessage["info"] = notes
206
207         if gidnumber is not None:
208             ldbmessage["gidNumber"] = normalise_int32(gidnumber)
209
210         if nisdomain is not None:
211             ldbmessage["msSFU30Name"] = groupname
212             ldbmessage["msSFU30NisDomain"] = nisdomain
213
214         if sd is not None:
215             ldbmessage["nTSecurityDescriptor"] = ndr_pack(sd)
216
217         self.add(ldbmessage)
218
219     def deletegroup(self, groupname):
220         """Deletes a group
221
222         :param groupname: Name of the target group
223         """
224
225         groupfilter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (ldb.binary_encode(groupname), "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())
226         self.transaction_start()
227         try:
228             targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
229                                expression=groupfilter, attrs=[])
230             if len(targetgroup) == 0:
231                 raise Exception('Unable to find group "%s"' % groupname)
232             assert(len(targetgroup) == 1)
233             self.delete(targetgroup[0].dn)
234         except:
235             self.transaction_cancel()
236             raise
237         else:
238             self.transaction_commit()
239
240     def add_remove_group_members(self, groupname, members,
241                                   add_members_operation=True):
242         """Adds or removes group members
243
244         :param groupname: Name of the target group
245         :param members: list of group members
246         :param add_members_operation: Defines if its an add or remove
247             operation
248         """
249
250         groupfilter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (
251             ldb.binary_encode(groupname), "CN=Group,CN=Schema,CN=Configuration", self.domain_dn())
252
253         self.transaction_start()
254         try:
255             targetgroup = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
256                                expression=groupfilter, attrs=['member'])
257             if len(targetgroup) == 0:
258                 raise Exception('Unable to find group "%s"' % groupname)
259             assert(len(targetgroup) == 1)
260
261             modified = False
262
263             addtargettogroup = """
264 dn: %s
265 changetype: modify
266 """ % (str(targetgroup[0].dn))
267
268             for member in members:
269                 targetmember = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
270                                     expression="(|(sAMAccountName=%s)(CN=%s))" % (
271                     ldb.binary_encode(member), ldb.binary_encode(member)), attrs=[])
272
273                 if len(targetmember) != 1:
274                     raise Exception('Unable to find "%s". Operation cancelled.' % member)
275
276                 if add_members_operation is True and (targetgroup[0].get('member') is None or str(targetmember[0].dn) not in targetgroup[0]['member']):
277                     modified = True
278                     addtargettogroup += """add: member
279 member: %s
280 """ % (str(targetmember[0].dn))
281
282                 elif add_members_operation is False and (targetgroup[0].get('member') is not None and str(targetmember[0].dn) in targetgroup[0]['member']):
283                     modified = True
284                     addtargettogroup += """delete: member
285 member: %s
286 """ % (str(targetmember[0].dn))
287
288             if modified is True:
289                 self.modify_ldif(addtargettogroup)
290
291         except:
292             self.transaction_cancel()
293             raise
294         else:
295             self.transaction_commit()
296
297     def newuser(self, username, password,
298             force_password_change_at_next_login_req=False,
299             useusernameascn=False, userou=None, surname=None, givenname=None,
300             initials=None, profilepath=None, scriptpath=None, homedrive=None,
301             homedirectory=None, jobtitle=None, department=None, company=None,
302             description=None, mailaddress=None, internetaddress=None,
303             telephonenumber=None, physicaldeliveryoffice=None, sd=None,
304             setpassword=True, uidnumber=None, gidnumber=None, gecos=None,
305             loginshell=None, uid=None, nisdomain=None, unixhome=None):
306         """Adds a new user with additional parameters
307
308         :param username: Name of the new user
309         :param password: Password for the new user
310         :param force_password_change_at_next_login_req: Force password change
311         :param useusernameascn: Use username as cn rather that firstname +
312             initials + lastname
313         :param userou: Object container (without domainDN postfix) for new user
314         :param surname: Surname of the new user
315         :param givenname: First name of the new user
316         :param initials: Initials of the new user
317         :param profilepath: Profile path of the new user
318         :param scriptpath: Logon script path of the new user
319         :param homedrive: Home drive of the new user
320         :param homedirectory: Home directory of the new user
321         :param jobtitle: Job title of the new user
322         :param department: Department of the new user
323         :param company: Company of the new user
324         :param description: of the new user
325         :param mailaddress: Email address of the new user
326         :param internetaddress: Home page of the new user
327         :param telephonenumber: Phone number of the new user
328         :param physicaldeliveryoffice: Office location of the new user
329         :param sd: security descriptor of the object
330         :param setpassword: optionally disable password reset
331         :param uidnumber: RFC2307 Unix numeric UID of the new user
332         :param gidnumber: RFC2307 Unix primary GID of the new user
333         :param gecos: RFC2307 Unix GECOS field of the new user
334         :param loginshell: RFC2307 Unix login shell of the new user
335         :param uid: RFC2307 Unix username of the new user
336         :param nisdomain: RFC2307 Unix NIS domain of the new user
337         :param unixhome: RFC2307 Unix home directory of the new user
338         """
339
340         displayname = ""
341         if givenname is not None:
342             displayname += givenname
343
344         if initials is not None:
345             displayname += ' %s.' % initials
346
347         if surname is not None:
348             displayname += ' %s' % surname
349
350         cn = username
351         if useusernameascn is None and displayname is not "":
352             cn = displayname
353
354         user_dn = "CN=%s,%s,%s" % (cn, (userou or "CN=Users"), self.domain_dn())
355
356         dnsdomain = ldb.Dn(self, self.domain_dn()).canonical_str().replace("/", "")
357         user_principal_name = "%s@%s" % (username, dnsdomain)
358         # The new user record. Note the reliance on the SAMLDB module which
359         # fills in the default informations
360         ldbmessage = {"dn": user_dn,
361                       "sAMAccountName": username,
362                       "userPrincipalName": user_principal_name,
363                       "objectClass": "user"}
364
365         if surname is not None:
366             ldbmessage["sn"] = surname
367
368         if givenname is not None:
369             ldbmessage["givenName"] = givenname
370
371         if displayname is not "":
372             ldbmessage["displayName"] = displayname
373             ldbmessage["name"] = displayname
374
375         if initials is not None:
376             ldbmessage["initials"] = '%s.' % initials
377
378         if profilepath is not None:
379             ldbmessage["profilePath"] = profilepath
380
381         if scriptpath is not None:
382             ldbmessage["scriptPath"] = scriptpath
383
384         if homedrive is not None:
385             ldbmessage["homeDrive"] = homedrive
386
387         if homedirectory is not None:
388             ldbmessage["homeDirectory"] = homedirectory
389
390         if jobtitle is not None:
391             ldbmessage["title"] = jobtitle
392
393         if department is not None:
394             ldbmessage["department"] = department
395
396         if company is not None:
397             ldbmessage["company"] = company
398
399         if description is not None:
400             ldbmessage["description"] = description
401
402         if mailaddress is not None:
403             ldbmessage["mail"] = mailaddress
404
405         if internetaddress is not None:
406             ldbmessage["wWWHomePage"] = internetaddress
407
408         if telephonenumber is not None:
409             ldbmessage["telephoneNumber"] = telephonenumber
410
411         if physicaldeliveryoffice is not None:
412             ldbmessage["physicalDeliveryOfficeName"] = physicaldeliveryoffice
413
414         if sd is not None:
415             ldbmessage["nTSecurityDescriptor"] = ndr_pack(sd)
416
417         ldbmessage2 = None
418         if any(map(lambda b: b is not None, (uid, uidnumber, gidnumber, gecos,
419                 loginshell, nisdomain, unixhome))):
420             ldbmessage2 = ldb.Message()
421             ldbmessage2.dn = ldb.Dn(self, user_dn)
422             ldbmessage2["objectClass"] = ldb.MessageElement('posixAccount', ldb.FLAG_MOD_ADD, 'objectClass')
423             if uid is not None:
424                 ldbmessage2["uid"] = ldb.MessageElement(str(uid), ldb.FLAG_MOD_REPLACE, 'uid')
425             if uidnumber is not None:
426                 ldbmessage2["uidNumber"] = ldb.MessageElement(str(uidnumber), ldb.FLAG_MOD_REPLACE, 'uidNumber')
427             if gidnumber is not None:
428                 ldbmessage2["gidNumber"] = ldb.MessageElement(str(gidnumber), ldb.FLAG_MOD_REPLACE, 'gidNumber')
429             if gecos is not None:
430                 ldbmessage2["gecos"] = ldb.MessageElement(str(gecos), ldb.FLAG_MOD_REPLACE, 'gecos')
431             if loginshell is not None:
432                 ldbmessage2["loginShell"] = ldb.MessageElement(str(loginshell), ldb.FLAG_MOD_REPLACE, 'loginShell')
433             if unixhome is not None:
434                 ldbmessage2["unixHomeDirectory"] = ldb.MessageElement(
435                     str(unixhome), ldb.FLAG_MOD_REPLACE, 'unixHomeDirectory')
436             if nisdomain is not None:
437                 ldbmessage2["msSFU30NisDomain"] = ldb.MessageElement(
438                     str(nisdomain), ldb.FLAG_MOD_REPLACE, 'msSFU30NisDomain')
439                 ldbmessage2["msSFU30Name"] = ldb.MessageElement(
440                     str(username), ldb.FLAG_MOD_REPLACE, 'msSFU30Name')
441                 ldbmessage2["unixUserPassword"] = ldb.MessageElement(
442                     'ABCD!efgh12345$67890', ldb.FLAG_MOD_REPLACE,
443                     'unixUserPassword')
444
445         self.transaction_start()
446         try:
447             self.add(ldbmessage)
448             if ldbmessage2:
449                 self.modify(ldbmessage2)
450
451             # Sets the password for it
452             if setpassword:
453                 self.setpassword("(samAccountName=%s)" % ldb.binary_encode(username), password,
454                                  force_password_change_at_next_login_req)
455         except:
456             self.transaction_cancel()
457             raise
458         else:
459             self.transaction_commit()
460
461
462     def deleteuser(self, username):
463         """Deletes a user
464
465         :param username: Name of the target user
466         """
467
468         filter = "(&(sAMAccountName=%s)(objectCategory=%s,%s))" % (ldb.binary_encode(username), "CN=Person,CN=Schema,CN=Configuration", self.domain_dn())
469         self.transaction_start()
470         try:
471             target = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
472                                  expression=filter, attrs=[])
473             if len(target) == 0:
474                 raise Exception('Unable to find user "%s"' % username)
475             assert(len(target) == 1)
476             self.delete(target[0].dn)
477         except:
478             self.transaction_cancel()
479             raise
480         else:
481             self.transaction_commit()
482
483     def setpassword(self, search_filter, password,
484             force_change_at_next_login=False, username=None):
485         """Sets the password for a user
486
487         :param search_filter: LDAP filter to find the user (eg
488             samccountname=name)
489         :param password: Password for the user
490         :param force_change_at_next_login: Force password change
491         """
492         self.transaction_start()
493         try:
494             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
495                               expression=search_filter, attrs=[])
496             if len(res) == 0:
497                 raise Exception('Unable to find user "%s"' % (username or search_filter))
498             if len(res) > 1:
499                 raise Exception('Matched %u multiple users with filter "%s"' % (len(res), search_filter))
500             user_dn = res[0].dn
501             pw = unicode('"' + password + '"', 'utf-8').encode('utf-16-le')
502             setpw = """
503 dn: %s
504 changetype: modify
505 replace: unicodePwd
506 unicodePwd:: %s
507 """ % (user_dn, base64.b64encode(pw))
508
509             self.modify_ldif(setpw)
510
511             if force_change_at_next_login:
512                 self.force_password_change_at_next_login(
513                   "(distinguishedName=" + str(user_dn) + ")")
514
515             #  modify the userAccountControl to remove the disabled bit
516             self.enable_account(search_filter)
517         except:
518             self.transaction_cancel()
519             raise
520         else:
521             self.transaction_commit()
522
523     def setexpiry(self, search_filter, expiry_seconds, no_expiry_req=False):
524         """Sets the account expiry for a user
525
526         :param search_filter: LDAP filter to find the user (eg
527             samaccountname=name)
528         :param expiry_seconds: expiry time from now in seconds
529         :param no_expiry_req: if set, then don't expire password
530         """
531         self.transaction_start()
532         try:
533             res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE,
534                           expression=search_filter,
535                           attrs=["userAccountControl", "accountExpires"])
536             if len(res) == 0:
537                 raise Exception('Unable to find user "%s"' % search_filter)
538             assert(len(res) == 1)
539             user_dn = res[0].dn
540
541             userAccountControl = int(res[0]["userAccountControl"][0])
542             accountExpires     = int(res[0]["accountExpires"][0])
543             if no_expiry_req:
544                 userAccountControl = userAccountControl | 0x10000
545                 accountExpires = 0
546             else:
547                 userAccountControl = userAccountControl & ~0x10000
548                 accountExpires = samba.unix2nttime(expiry_seconds + int(time.time()))
549
550             setexp = """
551 dn: %s
552 changetype: modify
553 replace: userAccountControl
554 userAccountControl: %u
555 replace: accountExpires
556 accountExpires: %u
557 """ % (user_dn, userAccountControl, accountExpires)
558
559             self.modify_ldif(setexp)
560         except:
561             self.transaction_cancel()
562             raise
563         else:
564             self.transaction_commit()
565
566     def set_domain_sid(self, sid):
567         """Change the domain SID used by this LDB.
568
569         :param sid: The new domain sid to use.
570         """
571         dsdb._samdb_set_domain_sid(self, sid)
572
573     def get_domain_sid(self):
574         """Read the domain SID used by this LDB. """
575         return dsdb._samdb_get_domain_sid(self)
576
577     domain_sid = property(get_domain_sid, set_domain_sid,
578         "SID for the domain")
579
580     def set_invocation_id(self, invocation_id):
581         """Set the invocation id for this SamDB handle.
582
583         :param invocation_id: GUID of the invocation id.
584         """
585         dsdb._dsdb_set_ntds_invocation_id(self, invocation_id)
586
587     def get_invocation_id(self):
588         """Get the invocation_id id"""
589         return dsdb._samdb_ntds_invocation_id(self)
590
591     invocation_id = property(get_invocation_id, set_invocation_id,
592         "Invocation ID GUID")
593
594     def get_oid_from_attid(self, attid):
595         return dsdb._dsdb_get_oid_from_attid(self, attid)
596
597     def get_attid_from_lDAPDisplayName(self, ldap_display_name,
598             is_schema_nc=False):
599         '''return the attribute ID for a LDAP attribute as an integer as found in DRSUAPI'''
600         return dsdb._dsdb_get_attid_from_lDAPDisplayName(self,
601             ldap_display_name, is_schema_nc)
602
603     def get_syntax_oid_from_lDAPDisplayName(self, ldap_display_name):
604         '''return the syntax OID for a LDAP attribute as a string'''
605         return dsdb._dsdb_get_syntax_oid_from_lDAPDisplayName(self, ldap_display_name)
606
607     def get_systemFlags_from_lDAPDisplayName(self, ldap_display_name):
608         '''return the systemFlags for a LDAP attribute as a integer'''
609         return dsdb._dsdb_get_systemFlags_from_lDAPDisplayName(self, ldap_display_name)
610
611     def get_linkId_from_lDAPDisplayName(self, ldap_display_name):
612         '''return the linkID for a LDAP attribute as a integer'''
613         return dsdb._dsdb_get_linkId_from_lDAPDisplayName(self, ldap_display_name)
614
615     def get_lDAPDisplayName_by_attid(self, attid):
616         '''return the lDAPDisplayName from an integer DRS attribute ID'''
617         return dsdb._dsdb_get_lDAPDisplayName_by_attid(self, attid)
618
619     def get_backlink_from_lDAPDisplayName(self, ldap_display_name):
620         '''return the attribute name of the corresponding backlink from the name
621         of a forward link attribute. If there is no backlink return None'''
622         return dsdb._dsdb_get_backlink_from_lDAPDisplayName(self, ldap_display_name)
623
624     def set_ntds_settings_dn(self, ntds_settings_dn):
625         """Set the NTDS Settings DN, as would be returned on the dsServiceName
626         rootDSE attribute.
627
628         This allows the DN to be set before the database fully exists
629
630         :param ntds_settings_dn: The new DN to use
631         """
632         dsdb._samdb_set_ntds_settings_dn(self, ntds_settings_dn)
633
634     def get_ntds_GUID(self):
635         """Get the NTDS objectGUID"""
636         return dsdb._samdb_ntds_objectGUID(self)
637
638     def server_site_name(self):
639         """Get the server site name"""
640         return dsdb._samdb_server_site_name(self)
641
642     def host_dns_name(self):
643         """return the DNS name of this host"""
644         res = self.search(base='', scope=ldb.SCOPE_BASE, attrs=['dNSHostName'])
645         return res[0]['dNSHostName'][0]
646
647     def domain_dns_name(self):
648         """return the DNS name of the domain root"""
649         domain_dn = self.get_default_basedn()
650         return domain_dn.canonical_str().split('/')[0]
651
652     def forest_dns_name(self):
653         """return the DNS name of the forest root"""
654         forest_dn = self.get_root_basedn()
655         return forest_dn.canonical_str().split('/')[0]
656
657     def load_partition_usn(self, base_dn):
658         return dsdb._dsdb_load_partition_usn(self, base_dn)
659
660     def set_schema(self, schema, write_indices_and_attributes=True):
661         self.set_schema_from_ldb(schema.ldb, write_indices_and_attributes=write_indices_and_attributes)
662
663     def set_schema_from_ldb(self, ldb_conn, write_indices_and_attributes=True):
664         dsdb._dsdb_set_schema_from_ldb(self, ldb_conn, write_indices_and_attributes)
665
666     def dsdb_DsReplicaAttribute(self, ldb, ldap_display_name, ldif_elements):
667         '''convert a list of attribute values to a DRSUAPI DsReplicaAttribute'''
668         return dsdb._dsdb_DsReplicaAttribute(ldb, ldap_display_name, ldif_elements)
669
670     def dsdb_normalise_attributes(self, ldb, ldap_display_name, ldif_elements):
671         '''normalise a list of attribute values'''
672         return dsdb._dsdb_normalise_attributes(ldb, ldap_display_name, ldif_elements)
673
674     def get_attribute_from_attid(self, attid):
675         """ Get from an attid the associated attribute
676
677         :param attid: The attribute id for searched attribute
678         :return: The name of the attribute associated with this id
679         """
680         if len(self.hash_oid_name.keys()) == 0:
681             self._populate_oid_attid()
682         if self.hash_oid_name.has_key(self.get_oid_from_attid(attid)):
683             return self.hash_oid_name[self.get_oid_from_attid(attid)]
684         else:
685             return None
686
687     def _populate_oid_attid(self):
688         """Populate the hash hash_oid_name.
689
690         This hash contains the oid of the attribute as a key and
691         its display name as a value
692         """
693         self.hash_oid_name = {}
694         res = self.search(expression="objectClass=attributeSchema",
695                            controls=["search_options:1:2"],
696                            attrs=["attributeID",
697                            "lDAPDisplayName"])
698         if len(res) > 0:
699             for e in res:
700                 strDisplay = str(e.get("lDAPDisplayName"))
701                 self.hash_oid_name[str(e.get("attributeID"))] = strDisplay
702
703     def get_attribute_replmetadata_version(self, dn, att):
704         """Get the version field trom the replPropertyMetaData for
705         the given field
706
707         :param dn: The on which we want to get the version
708         :param att: The name of the attribute
709         :return: The value of the version field in the replPropertyMetaData
710             for the given attribute. None if the attribute is not replicated
711         """
712
713         res = self.search(expression="distinguishedName=%s" % dn,
714                             scope=ldb.SCOPE_SUBTREE,
715                             controls=["search_options:1:2"],
716                             attrs=["replPropertyMetaData"])
717         if len(res) == 0:
718             return None
719
720         repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
721                             str(res[0]["replPropertyMetaData"]))
722         ctr = repl.ctr
723         if len(self.hash_oid_name.keys()) == 0:
724             self._populate_oid_attid()
725         for o in ctr.array:
726             # Search for Description
727             att_oid = self.get_oid_from_attid(o.attid)
728             if self.hash_oid_name.has_key(att_oid) and\
729                att.lower() == self.hash_oid_name[att_oid].lower():
730                 return o.version
731         return None
732
733     def set_attribute_replmetadata_version(self, dn, att, value,
734             addifnotexist=False):
735         res = self.search(expression="distinguishedName=%s" % dn,
736                             scope=ldb.SCOPE_SUBTREE,
737                             controls=["search_options:1:2"],
738                             attrs=["replPropertyMetaData"])
739         if len(res) == 0:
740             return None
741
742         repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
743                             str(res[0]["replPropertyMetaData"]))
744         ctr = repl.ctr
745         now = samba.unix2nttime(int(time.time()))
746         found = False
747         if len(self.hash_oid_name.keys()) == 0:
748             self._populate_oid_attid()
749         for o in ctr.array:
750             # Search for Description
751             att_oid = self.get_oid_from_attid(o.attid)
752             if self.hash_oid_name.has_key(att_oid) and\
753                att.lower() == self.hash_oid_name[att_oid].lower():
754                 found = True
755                 seq = self.sequence_number(ldb.SEQ_NEXT)
756                 o.version = value
757                 o.originating_change_time = now
758                 o.originating_invocation_id = misc.GUID(self.get_invocation_id())
759                 o.originating_usn = seq
760                 o.local_usn = seq
761
762         if not found and addifnotexist and len(ctr.array) >0:
763             o2 = drsblobs.replPropertyMetaData1()
764             o2.attid = 589914
765             att_oid = self.get_oid_from_attid(o2.attid)
766             seq = self.sequence_number(ldb.SEQ_NEXT)
767             o2.version = value
768             o2.originating_change_time = now
769             o2.originating_invocation_id = misc.GUID(self.get_invocation_id())
770             o2.originating_usn = seq
771             o2.local_usn = seq
772             found = True
773             tab = ctr.array
774             tab.append(o2)
775             ctr.count = ctr.count + 1
776             ctr.array = tab
777
778         if found :
779             replBlob = ndr_pack(repl)
780             msg = ldb.Message()
781             msg.dn = res[0].dn
782             msg["replPropertyMetaData"] = ldb.MessageElement(replBlob,
783                                                 ldb.FLAG_MOD_REPLACE,
784                                                 "replPropertyMetaData")
785             self.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])
786
787     def write_prefixes_from_schema(self):
788         dsdb._dsdb_write_prefixes_from_schema_to_ldb(self)
789
790     def get_partitions_dn(self):
791         return dsdb._dsdb_get_partitions_dn(self)
792
793     def get_nc_root(self, dn):
794         return dsdb._dsdb_get_nc_root(self, dn)
795
796     def get_wellknown_dn(self, nc_root, wkguid):
797         h_nc = self.hash_well_known.get(str(nc_root))
798         dn = None
799         if h_nc is not None:
800             dn = h_nc.get(wkguid)
801         if dn is None:
802             dn = dsdb._dsdb_get_wellknown_dn(self, nc_root, wkguid)
803             if dn is None:
804                 return dn
805             if h_nc is None:
806                 self.hash_well_known[str(nc_root)] = {}
807                 h_nc = self.hash_well_known[str(nc_root)]
808             h_nc[wkguid] = dn
809         return dn
810
811     def set_minPwdAge(self, value):
812         m = ldb.Message()
813         m.dn = ldb.Dn(self, self.domain_dn())
814         m["minPwdAge"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdAge")
815         self.modify(m)
816
817     def get_minPwdAge(self):
818         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdAge"])
819         if len(res) == 0:
820             return None
821         elif not "minPwdAge" in res[0]:
822             return None
823         else:
824             return res[0]["minPwdAge"][0]
825
826     def set_minPwdLength(self, value):
827         m = ldb.Message()
828         m.dn = ldb.Dn(self, self.domain_dn())
829         m["minPwdLength"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "minPwdLength")
830         self.modify(m)
831
832     def get_minPwdLength(self):
833         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["minPwdLength"])
834         if len(res) == 0:
835             return None
836         elif not "minPwdLength" in res[0]:
837             return None
838         else:
839             return res[0]["minPwdLength"][0]
840
841     def set_pwdProperties(self, value):
842         m = ldb.Message()
843         m.dn = ldb.Dn(self, self.domain_dn())
844         m["pwdProperties"] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, "pwdProperties")
845         self.modify(m)
846
847     def get_pwdProperties(self):
848         res = self.search(self.domain_dn(), scope=ldb.SCOPE_BASE, attrs=["pwdProperties"])
849         if len(res) == 0:
850             return None
851         elif not "pwdProperties" in res[0]:
852             return None
853         else:
854             return res[0]["pwdProperties"][0]
855
856     def set_dsheuristics(self, dsheuristics):
857         m = ldb.Message()
858         m.dn = ldb.Dn(self, "CN=Directory Service,CN=Windows NT,CN=Services,%s"
859                       % self.get_config_basedn().get_linearized())
860         if dsheuristics is not None:
861             m["dSHeuristics"] = ldb.MessageElement(dsheuristics,
862                 ldb.FLAG_MOD_REPLACE, "dSHeuristics")
863         else:
864             m["dSHeuristics"] = ldb.MessageElement([], ldb.FLAG_MOD_DELETE,
865                 "dSHeuristics")
866         self.modify(m)
867
868     def get_dsheuristics(self):
869         res = self.search("CN=Directory Service,CN=Windows NT,CN=Services,%s"
870                           % self.get_config_basedn().get_linearized(),
871                           scope=ldb.SCOPE_BASE, attrs=["dSHeuristics"])
872         if len(res) == 0:
873             dsheuristics = None
874         elif "dSHeuristics" in res[0]:
875             dsheuristics = res[0]["dSHeuristics"][0]
876         else:
877             dsheuristics = None
878
879         return dsheuristics
880
881     def create_ou(self, ou_dn, description=None, name=None, sd=None):
882         """Creates an organizationalUnit object
883         :param ou_dn: dn of the new object
884         :param description: description attribute
885         :param name: name atttribute
886         :param sd: security descriptor of the object, can be
887         an SDDL string or security.descriptor type
888         """
889         m = {"dn": ou_dn,
890              "objectClass": "organizationalUnit"}
891
892         if description:
893             m["description"] = description
894         if name:
895             m["name"] = name
896
897         if sd:
898             m["nTSecurityDescriptor"] = ndr_pack(sd)
899         self.add(m)
900
901     def sequence_number(self, seq_type):
902         """Returns the value of the sequence number according to the requested type
903         :param seq_type: type of sequence number
904          """
905         self.transaction_start()
906         try:
907             seq = super(SamDB, self).sequence_number(seq_type)
908         except:
909             self.transaction_cancel()
910             raise
911         else:
912             self.transaction_commit()
913         return seq
914
915     def get_dsServiceName(self):
916         '''get the NTDS DN from the rootDSE'''
917         res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["dsServiceName"])
918         return res[0]["dsServiceName"][0]
919
920     def get_serverName(self):
921         '''get the server DN from the rootDSE'''
922         res = self.search(base="", scope=ldb.SCOPE_BASE, attrs=["serverName"])
923         return res[0]["serverName"][0]
924
925     def dns_lookup(self, dns_name):
926         '''Do a DNS lookup in the database, returns the NDR database structures'''
927         return dsdb_dns.lookup(self, dns_name)
928
929     def dns_replace(self, dns_name, new_records):
930         '''Do a DNS modification on the database, sets the NDR database structures'''
931         return dsdb_dns.replace(self, dns_name, new_records)