python: Make top-level samba modules Python 3 compatible
[nivanova/samba-autobuild/.git] / python / samba / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2008
3 #
4 # Based on the original in EJS:
5 # Copyright (C) Andrew Tridgell <tridge@samba.org> 2005
6 #
7 # This program is free software; you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation; either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20
21 """Samba 4."""
22
23 __docformat__ = "restructuredText"
24
25 import os
26 import sys
27 import time
28 import ldb
29 from samba.compat import PY3
30 import samba.param
31 from samba import _glue
32 if not PY3:
33     from samba._ldb import Ldb as _Ldb
34 else:
35     # samba._ldb is not yet ported to Python 3
36     _Ldb = object
37
38
39 def source_tree_topdir():
40     """Return the top level source directory."""
41     paths = ["../../..", "../../../.."]
42     for p in paths:
43         topdir = os.path.normpath(os.path.join(os.path.dirname(__file__), p))
44         if os.path.exists(os.path.join(topdir, 'source4')):
45             return topdir
46     raise RuntimeError("unable to find top level source directory")
47
48
49 def in_source_tree():
50     """Return True if we are running from within the samba source tree"""
51     try:
52         topdir = source_tree_topdir()
53     except RuntimeError:
54         return False
55     return True
56
57
58 class Ldb(_Ldb):
59     """Simple Samba-specific LDB subclass that takes care
60     of setting up the modules dir, credentials pointers, etc.
61
62     Please note that this is intended to be for all Samba LDB files,
63     not necessarily the Sam database. For Sam-specific helper
64     functions see samdb.py.
65     """
66
67     def __init__(self, url=None, lp=None, modules_dir=None, session_info=None,
68                  credentials=None, flags=0, options=None):
69         """Opens a Samba Ldb file.
70
71         :param url: Optional LDB URL to open
72         :param lp: Optional loadparm object
73         :param modules_dir: Optional modules directory
74         :param session_info: Optional session information
75         :param credentials: Optional credentials, defaults to anonymous.
76         :param flags: Optional LDB flags
77         :param options: Additional options (optional)
78
79         This is different from a regular Ldb file in that the Samba-specific
80         modules-dir is used by default and that credentials and session_info
81         can be passed through (required by some modules).
82         """
83
84         if modules_dir is not None:
85             self.set_modules_dir(modules_dir)
86         else:
87             self.set_modules_dir(os.path.join(samba.param.modules_dir(), "ldb"))
88
89         if session_info is not None:
90             self.set_session_info(session_info)
91
92         if credentials is not None:
93             self.set_credentials(credentials)
94
95         if lp is not None:
96             self.set_loadparm(lp)
97
98         # This must be done before we load the schema, as these handlers for
99         # objectSid and objectGUID etc must take precedence over the 'binary
100         # attribute' declaration in the schema
101         self.register_samba_handlers()
102
103         # TODO set debug
104         def msg(l, text):
105             print(text)
106         #self.set_debug(msg)
107
108         self.set_utf8_casefold()
109
110         # Allow admins to force non-sync ldb for all databases
111         if lp is not None:
112             nosync_p = lp.get("nosync", "ldb")
113             if nosync_p is not None and nosync_p:
114                 flags |= ldb.FLG_NOSYNC
115
116         self.set_create_perms(0o600)
117
118         if url is not None:
119             self.connect(url, flags, options)
120
121     def searchone(self, attribute, basedn=None, expression=None,
122                   scope=ldb.SCOPE_BASE):
123         """Search for one attribute as a string.
124
125         :param basedn: BaseDN for the search.
126         :param attribute: Name of the attribute
127         :param expression: Optional search expression.
128         :param scope: Search scope (defaults to base).
129         :return: Value of attribute as a string or None if it wasn't found.
130         """
131         res = self.search(basedn, scope, expression, [attribute])
132         if len(res) != 1 or res[0][attribute] is None:
133             return None
134         values = set(res[0][attribute])
135         assert len(values) == 1
136         return self.schema_format_value(attribute, values.pop())
137
138     def erase_users_computers(self, dn):
139         """Erases user and computer objects from our AD.
140
141         This is needed since the 'samldb' module denies the deletion of primary
142         groups. Therefore all groups shouldn't be primary somewhere anymore.
143         """
144
145         try:
146             res = self.search(base=dn, scope=ldb.SCOPE_SUBTREE, attrs=[],
147                       expression="(|(objectclass=user)(objectclass=computer))")
148         except ldb.LdbError as error:
149             (errno, estr) = error.args
150             if errno == ldb.ERR_NO_SUCH_OBJECT:
151                 # Ignore no such object errors
152                 return
153             else:
154                 raise
155
156         try:
157             for msg in res:
158                 self.delete(msg.dn, ["relax:0"])
159         except ldb.LdbError as error:
160             (errno, estr) = error.args
161             if errno != ldb.ERR_NO_SUCH_OBJECT:
162                 # Ignore no such object errors
163                 raise
164
165     def erase_except_schema_controlled(self):
166         """Erase this ldb.
167
168         :note: Removes all records, except those that are controlled by
169             Samba4's schema.
170         """
171
172         basedn = ""
173
174         # Try to delete user/computer accounts to allow deletion of groups
175         self.erase_users_computers(basedn)
176
177         # Delete the 'visible' records, and the invisble 'deleted' records (if
178         # this DB supports it)
179         for msg in self.search(basedn, ldb.SCOPE_SUBTREE,
180                        "(&(|(objectclass=*)(distinguishedName=*))(!(distinguishedName=@BASEINFO)))",
181                        [], controls=["show_deleted:0", "show_recycled:0"]):
182             try:
183                 self.delete(msg.dn, ["relax:0"])
184             except ldb.LdbError as error:
185                 (errno, estr) = error.args
186                 if errno != ldb.ERR_NO_SUCH_OBJECT:
187                     # Ignore no such object errors
188                     raise
189
190         res = self.search(basedn, ldb.SCOPE_SUBTREE,
191             "(&(|(objectclass=*)(distinguishedName=*))(!(distinguishedName=@BASEINFO)))",
192             [], controls=["show_deleted:0", "show_recycled:0"])
193         assert len(res) == 0
194
195         # delete the specials
196         for attr in ["@SUBCLASSES", "@MODULES",
197                      "@OPTIONS", "@PARTITION", "@KLUDGEACL"]:
198             try:
199                 self.delete(attr, ["relax:0"])
200             except ldb.LdbError as error:
201                 (errno, estr) = error.args
202                 if errno != ldb.ERR_NO_SUCH_OBJECT:
203                     # Ignore missing dn errors
204                     raise
205
206     def erase(self):
207         """Erase this ldb, removing all records."""
208         self.erase_except_schema_controlled()
209
210         # delete the specials
211         for attr in ["@INDEXLIST", "@ATTRIBUTES"]:
212             try:
213                 self.delete(attr, ["relax:0"])
214             except ldb.LdbError as error:
215                 (errno, estr) = error.args
216                 if errno != ldb.ERR_NO_SUCH_OBJECT:
217                     # Ignore missing dn errors
218                     raise
219
220     def load_ldif_file_add(self, ldif_path):
221         """Load a LDIF file.
222
223         :param ldif_path: Path to LDIF file.
224         """
225         self.add_ldif(open(ldif_path, 'r').read())
226
227     def add_ldif(self, ldif, controls=None):
228         """Add data based on a LDIF string.
229
230         :param ldif: LDIF text.
231         """
232         for changetype, msg in self.parse_ldif(ldif):
233             assert changetype == ldb.CHANGETYPE_NONE
234             self.add(msg, controls)
235
236     def modify_ldif(self, ldif, controls=None):
237         """Modify database based on a LDIF string.
238
239         :param ldif: LDIF text.
240         """
241         for changetype, msg in self.parse_ldif(ldif):
242             if changetype == ldb.CHANGETYPE_ADD:
243                 self.add(msg, controls)
244             else:
245                 self.modify(msg, controls)
246
247
248 def substitute_var(text, values):
249     """Substitute strings of the form ${NAME} in str, replacing
250     with substitutions from values.
251
252     :param text: Text in which to subsitute.
253     :param values: Dictionary with keys and values.
254     """
255
256     for (name, value) in values.items():
257         assert isinstance(name, str), "%r is not a string" % name
258         assert isinstance(value, str), "Value %r for %s is not a string" % (value, name)
259         text = text.replace("${%s}" % name, value)
260
261     return text
262
263
264 def check_all_substituted(text):
265     """Check that all substitution variables in a string have been replaced.
266
267     If not, raise an exception.
268
269     :param text: The text to search for substitution variables
270     """
271     if not "${" in text:
272         return
273
274     var_start = text.find("${")
275     var_end = text.find("}", var_start)
276
277     raise Exception("Not all variables substituted: %s" %
278         text[var_start:var_end+1])
279
280
281 def read_and_sub_file(file_name, subst_vars):
282     """Read a file and sub in variables found in it
283
284     :param file_name: File to be read (typically from setup directory)
285      param subst_vars: Optional variables to subsitute in the file.
286     """
287     data = open(file_name, 'r').read()
288     if subst_vars is not None:
289         data = substitute_var(data, subst_vars)
290         check_all_substituted(data)
291     return data
292
293
294 def setup_file(template, fname, subst_vars=None):
295     """Setup a file in the private dir.
296
297     :param template: Path of the template file.
298     :param fname: Path of the file to create.
299     :param subst_vars: Substitution variables.
300     """
301     if os.path.exists(fname):
302         os.unlink(fname)
303
304     data = read_and_sub_file(template, subst_vars)
305     f = open(fname, 'w')
306     try:
307         f.write(data)
308     finally:
309         f.close()
310
311 MAX_NETBIOS_NAME_LEN = 15
312 def is_valid_netbios_char(c):
313     return (c.isalnum() or c in " !#$%&'()-.@^_{}~")
314
315
316 def valid_netbios_name(name):
317     """Check whether a name is valid as a NetBIOS name. """
318     # See crh's book (1.4.1.1)
319     if len(name) > MAX_NETBIOS_NAME_LEN:
320         return False
321     for x in name:
322         if not is_valid_netbios_char(x):
323             return False
324     return True
325
326
327 def import_bundled_package(modulename, location, source_tree_container,
328                            namespace):
329     """Import the bundled version of a package.
330
331     :note: This should only be called if the system version of the package
332         is not adequate.
333
334     :param modulename: Module name to import
335     :param location: Location to add to sys.path (can be relative to
336         ${srcdir}/${source_tree_container})
337     :param source_tree_container: Directory under source root that
338         contains the bundled third party modules.
339     :param namespace: Namespace to import module from, when not in source tree
340     """
341     if in_source_tree():
342         extra_path = os.path.join(source_tree_topdir(), source_tree_container,
343             location)
344         if not extra_path in sys.path:
345             sys.path.insert(0, extra_path)
346         sys.modules[modulename] = __import__(modulename)
347     else:
348         sys.modules[modulename] = __import__(
349             "%s.%s" % (namespace, modulename), fromlist=[namespace])
350
351
352 def ensure_third_party_module(modulename, location):
353     """Add a location to sys.path if a third party dependency can't be found.
354
355     :param modulename: Module name to import
356     :param location: Location to add to sys.path (can be relative to
357         ${srcdir}/third_party)
358     """
359     try:
360         __import__(modulename)
361     except ImportError:
362         import_bundled_package(modulename, location,
363             source_tree_container="third_party",
364             namespace="samba.third_party")
365
366
367 def dn_from_dns_name(dnsdomain):
368     """return a DN from a DNS name domain/forest root"""
369     return "DC=" + ",DC=".join(dnsdomain.split("."))
370
371 def current_unix_time():
372     return int(time.time())
373
374 def string_to_byte_array(string):
375     blob = [0] * len(string)
376
377     for i in range(len(string)):
378         blob[i] = ord(string[i])
379
380     return blob
381
382 def arcfour_encrypt(key, data):
383     try:
384         from Crypto.Cipher import ARC4
385         c = ARC4.new(key)
386         return c.encrypt(data)
387     except ImportError as e:
388         pass
389     try:
390         from M2Crypto.RC4 import RC4
391         c = RC4(key)
392         return c.update(data)
393     except ImportError as e:
394         pass
395     raise Exception("arcfour_encrypt() requires " +
396                     "python*-crypto or python*-m2crypto or m2crypto")
397
398 version = _glue.version
399 interface_ips = _glue.interface_ips
400 set_debug_level = _glue.set_debug_level
401 get_debug_level = _glue.get_debug_level
402 unix2nttime = _glue.unix2nttime
403 nttime2string = _glue.nttime2string
404 nttime2unix = _glue.nttime2unix
405 unix2nttime = _glue.unix2nttime
406 generate_random_password = _glue.generate_random_password
407 generate_random_machine_password = _glue.generate_random_machine_password
408 strcasecmp_m = _glue.strcasecmp_m
409 strstr_m = _glue.strstr_m
410 is_ntvfs_fileserver_built = _glue.is_ntvfs_fileserver_built
411
412 NTSTATUSError = _glue.NTSTATUSError
413 HRESULTError = _glue.HRESULTError
414 WERRORError = _glue.WERRORError
415 DsExtendedError = _glue.DsExtendedError