python: samba.tests: Move import of ported modules out of PY3 condition
[sfrench/samba-autobuild/.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import ldb
23 import samba
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 import socket
28 import struct
29 import subprocess
30 import sys
31 import tempfile
32 import unittest
33 import samba.auth
34 import samba.dcerpc.base
35 from samba.compat import PY3
36 if not PY3:
37     # Py2 only
38     from samba.samdb import SamDB
39     import samba.ndr
40     import samba.dcerpc.dcerpc
41     import samba.dcerpc.epmapper
42     from samba import gensec
43
44 try:
45     from unittest import SkipTest
46 except ImportError:
47     class SkipTest(Exception):
48         """Test skipped."""
49
50 HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
51
52 class TestCase(unittest.TestCase):
53     """A Samba test case."""
54
55     def setUp(self):
56         super(TestCase, self).setUp()
57         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
58         if test_debug_level is not None:
59             test_debug_level = int(test_debug_level)
60             self._old_debug_level = samba.get_debug_level()
61             samba.set_debug_level(test_debug_level)
62             self.addCleanup(samba.set_debug_level, test_debug_level)
63
64     def get_loadparm(self):
65         return env_loadparm()
66
67     def get_credentials(self):
68         return cmdline_credentials
69
70     def hexdump(self, src):
71         N = 0
72         result = ''
73         while src:
74             ll = src[:8]
75             lr = src[8:16]
76             src = src[16:]
77             hl = ' '.join(["%02X" % ord(x) for x in ll])
78             hr = ' '.join(["%02X" % ord(x) for x in lr])
79             ll = ll.translate(HEXDUMP_FILTER)
80             lr = lr.translate(HEXDUMP_FILTER)
81             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
82             N += 16
83         return result
84
85     # These functions didn't exist before Python2.7:
86     if sys.version_info < (2, 7):
87         import warnings
88
89         def skipTest(self, reason):
90             raise SkipTest(reason)
91
92         def assertIn(self, member, container, msg=None):
93             self.assertTrue(member in container, msg)
94
95         def assertIs(self, a, b, msg=None):
96             self.assertTrue(a is b, msg)
97
98         def assertIsNot(self, a, b, msg=None):
99             self.assertTrue(a is not b, msg)
100
101         def assertIsNotNone(self, a, msg=None):
102             self.assertTrue(a is not None)
103
104         def assertIsInstance(self, a, b, msg=None):
105             self.assertTrue(isinstance(a, b), msg)
106
107         def assertIsNone(self, a, msg=None):
108             self.assertTrue(a is None, msg)
109
110         def assertGreater(self, a, b, msg=None):
111             self.assertTrue(a > b, msg)
112
113         def assertGreaterEqual(self, a, b, msg=None):
114             self.assertTrue(a >= b, msg)
115
116         def assertLess(self, a, b, msg=None):
117             self.assertTrue(a < b, msg)
118
119         def assertLessEqual(self, a, b, msg=None):
120             self.assertTrue(a <= b, msg)
121
122         def addCleanup(self, fn, *args, **kwargs):
123             self._cleanups = getattr(self, "_cleanups", []) + [
124                 (fn, args, kwargs)]
125
126         def _addSkip(self, result, reason):
127             addSkip = getattr(result, 'addSkip', None)
128             if addSkip is not None:
129                 addSkip(self, reason)
130             else:
131                 warnings.warn("TestResult has no addSkip method, skips not reported",
132                               RuntimeWarning, 2)
133                 result.addSuccess(self)
134
135         def run(self, result=None):
136             if result is None: result = self.defaultTestResult()
137             result.startTest(self)
138             testMethod = getattr(self, self._testMethodName)
139             try:
140                 try:
141                     self.setUp()
142                 except SkipTest as e:
143                     self._addSkip(result, str(e))
144                     return
145                 except KeyboardInterrupt:
146                     raise
147                 except:
148                     result.addError(self, self._exc_info())
149                     return
150
151                 ok = False
152                 try:
153                     testMethod()
154                     ok = True
155                 except SkipTest as e:
156                     self._addSkip(result, str(e))
157                     return
158                 except self.failureException:
159                     result.addFailure(self, self._exc_info())
160                 except KeyboardInterrupt:
161                     raise
162                 except:
163                     result.addError(self, self._exc_info())
164
165                 try:
166                     self.tearDown()
167                 except SkipTest as e:
168                     self._addSkip(result, str(e))
169                 except KeyboardInterrupt:
170                     raise
171                 except:
172                     result.addError(self, self._exc_info())
173                     ok = False
174
175                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
176                     fn(*args, **kwargs)
177                 if ok: result.addSuccess(self)
178             finally:
179                 result.stopTest(self)
180
181
182 class LdbTestCase(TestCase):
183     """Trivial test case for running tests against a LDB."""
184
185     def setUp(self):
186         super(LdbTestCase, self).setUp()
187         self.filename = os.tempnam()
188         self.ldb = samba.Ldb(self.filename)
189
190     def set_modules(self, modules=[]):
191         """Change the modules for this Ldb."""
192         m = ldb.Message()
193         m.dn = ldb.Dn(self.ldb, "@MODULES")
194         m["@LIST"] = ",".join(modules)
195         self.ldb.add(m)
196         self.ldb = samba.Ldb(self.filename)
197
198
199 class TestCaseInTempDir(TestCase):
200
201     def setUp(self):
202         super(TestCaseInTempDir, self).setUp()
203         self.tempdir = tempfile.mkdtemp()
204         self.addCleanup(self._remove_tempdir)
205
206     def _remove_tempdir(self):
207         self.assertEquals([], os.listdir(self.tempdir))
208         os.rmdir(self.tempdir)
209         self.tempdir = None
210
211
212 def env_loadparm():
213     lp = param.LoadParm()
214     try:
215         lp.load(os.environ["SMB_CONF_PATH"])
216     except KeyError:
217         raise KeyError("SMB_CONF_PATH not set")
218     return lp
219
220
221 def env_get_var_value(var_name, allow_missing=False):
222     """Returns value for variable in os.environ
223
224     Function throws AssertionError if variable is defined.
225     Unit-test based python tests require certain input params
226     to be set in environment, otherwise they can't be run
227     """
228     if allow_missing:
229         if var_name not in os.environ.keys():
230             return None
231     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
232     return os.environ[var_name]
233
234
235 cmdline_credentials = None
236
237 class RpcInterfaceTestCase(TestCase):
238     """DCE/RPC Test case."""
239
240
241 class ValidNetbiosNameTests(TestCase):
242
243     def test_valid(self):
244         self.assertTrue(samba.valid_netbios_name("FOO"))
245
246     def test_too_long(self):
247         self.assertFalse(samba.valid_netbios_name("FOO"*10))
248
249     def test_invalid_characters(self):
250         self.assertFalse(samba.valid_netbios_name("*BLA"))
251
252
253 class BlackboxProcessError(Exception):
254     """This is raised when check_output() process returns a non-zero exit status
255
256     Exception instance should contain the exact exit code (S.returncode),
257     command line (S.cmd), process output (S.stdout) and process error stream
258     (S.stderr)
259     """
260
261     def __init__(self, returncode, cmd, stdout, stderr):
262         self.returncode = returncode
263         self.cmd = cmd
264         self.stdout = stdout
265         self.stderr = stderr
266
267     def __str__(self):
268         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
269                                                                              self.stdout, self.stderr)
270
271 class BlackboxTestCase(TestCaseInTempDir):
272     """Base test case for blackbox tests."""
273
274     def _make_cmdline(self, line):
275         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
276         parts = line.split(" ")
277         if os.path.exists(os.path.join(bindir, parts[0])):
278             parts[0] = os.path.join(bindir, parts[0])
279         line = " ".join(parts)
280         return line
281
282     def check_run(self, line):
283         line = self._make_cmdline(line)
284         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
285         retcode = p.wait()
286         if retcode:
287             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
288
289     def check_output(self, line):
290         line = self._make_cmdline(line)
291         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
292         retcode = p.wait()
293         if retcode:
294             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
295         return p.stdout.read()
296
297
298 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
299                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
300     """Create SamDB instance and connects to samdb_url database.
301
302     :param samdb_url: Url for database to connect to.
303     :param lp: Optional loadparm object
304     :param session_info: Optional session information
305     :param credentials: Optional credentials, defaults to anonymous.
306     :param flags: Optional LDB flags
307     :param ldap_only: If set, only remote LDAP connection will be created.
308     :param global_schema: Whether to use global schema.
309
310     Added value for tests is that we have a shorthand function
311     to make proper URL for ldb.connect() while using default
312     parameters for connection based on test environment
313     """
314     if not "://" in samdb_url:
315         if not ldap_only and os.path.isfile(samdb_url):
316             samdb_url = "tdb://%s" % samdb_url
317         else:
318             samdb_url = "ldap://%s" % samdb_url
319     # use 'paged_search' module when connecting remotely
320     if samdb_url.startswith("ldap://"):
321         ldb_options = ["modules:paged_searches"]
322     elif ldap_only:
323         raise AssertionError("Trying to connect to %s while remote "
324                              "connection is required" % samdb_url)
325
326     # set defaults for test environment
327     if lp is None:
328         lp = env_loadparm()
329     if session_info is None:
330         session_info = samba.auth.system_session(lp)
331     if credentials is None:
332         credentials = cmdline_credentials
333
334     return SamDB(url=samdb_url,
335                  lp=lp,
336                  session_info=session_info,
337                  credentials=credentials,
338                  flags=flags,
339                  options=ldb_options,
340                  global_schema=global_schema)
341
342
343 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
344                      flags=0, ldb_options=None, ldap_only=False):
345     """Connects to samdb_url database
346
347     :param samdb_url: Url for database to connect to.
348     :param lp: Optional loadparm object
349     :param session_info: Optional session information
350     :param credentials: Optional credentials, defaults to anonymous.
351     :param flags: Optional LDB flags
352     :param ldap_only: If set, only remote LDAP connection will be created.
353     :return: (sam_db_connection, rootDse_record) tuple
354     """
355     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
356                            flags, ldb_options, ldap_only)
357     # fetch RootDse
358     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
359                         attrs=["*"])
360     return (sam_db, res[0])
361
362
363 def connect_samdb_env(env_url, env_username, env_password, lp=None):
364     """Connect to SamDB by getting URL and Credentials from environment
365
366     :param env_url: Environment variable name to get lsb url from
367     :param env_username: Username environment variable
368     :param env_password: Password environment variable
369     :return: sam_db_connection
370     """
371     samdb_url = env_get_var_value(env_url)
372     creds = credentials.Credentials()
373     if lp is None:
374         # guess Credentials parameters here. Otherwise workstation
375         # and domain fields are NULL and gencache code segfalts
376         lp = param.LoadParm()
377         creds.guess(lp)
378     creds.set_username(env_get_var_value(env_username))
379     creds.set_password(env_get_var_value(env_password))
380     return connect_samdb(samdb_url, credentials=creds, lp=lp)
381
382
383 def delete_force(samdb, dn):
384     try:
385         samdb.delete(dn)
386     except ldb.LdbError as error:
387         (num, errstr) = error.args
388         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr