python/samba/tests: make sure samba-tool is called with ${PYTHON}
[amitay/samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import tempfile
23 import ldb
24 import samba
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
29 import socket
30 import struct
31 import subprocess
32 import sys
33 import tempfile
34 import unittest
35 import re
36 import samba.auth
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
41 from random import SystemRandom
42 import string
43 try:
44     from samba.samdb import SamDB
45 except ImportError:
46     # We are built without samdb support,
47     # imitate it so that connect_samdb() can recover
48     def SamDB(*args, **kwargs):
49         return None
50
51 import samba.ndr
52 import samba.dcerpc.dcerpc
53 import samba.dcerpc.epmapper
54
55 try:
56     from unittest import SkipTest
57 except ImportError:
58     class SkipTest(Exception):
59         """Test skipped."""
60
61 HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
62
63
64 class TestCase(unittest.TestCase):
65     """A Samba test case."""
66
67     def setUp(self):
68         super(TestCase, self).setUp()
69         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
70         if test_debug_level is not None:
71             test_debug_level = int(test_debug_level)
72             self._old_debug_level = samba.get_debug_level()
73             samba.set_debug_level(test_debug_level)
74             self.addCleanup(samba.set_debug_level, test_debug_level)
75
76     def get_loadparm(self):
77         return env_loadparm()
78
79     def get_credentials(self):
80         return cmdline_credentials
81
82     def get_creds_ccache_name(self):
83         creds = self.get_credentials()
84         ccache = creds.get_named_ccache(self.get_loadparm())
85         ccache_name = ccache.get_name()
86
87         return ccache_name
88
89     def hexdump(self, src):
90         N = 0
91         result = ''
92         is_string = isinstance(src, string_types)
93         while src:
94             ll = src[:8]
95             lr = src[8:16]
96             src = src[16:]
97             if is_string:
98                 hl = ' '.join(["%02X" % ord(x) for x in ll])
99                 hr = ' '.join(["%02X" % ord(x) for x in lr])
100                 ll = ll.translate(HEXDUMP_FILTER)
101                 lr = lr.translate(HEXDUMP_FILTER)
102             else:
103                 hl = ' '.join(["%02X" % x for x in ll])
104                 hr = ' '.join(["%02X" % x for x in lr])
105                 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
106                 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
107             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
108             N += 16
109         return result
110
111     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
112
113         if template is None:
114             assert template is not None
115
116         if username is not None:
117             assert userpass is not None
118
119         if username is None:
120             assert userpass is None
121
122             username = template.get_username()
123             userpass = template.get_password()
124
125         if kerberos_state is None:
126             kerberos_state = template.get_kerberos_state()
127
128         # get a copy of the global creds or a the passed in creds
129         c = Credentials()
130         c.set_username(username)
131         c.set_password(userpass)
132         c.set_domain(template.get_domain())
133         c.set_realm(template.get_realm())
134         c.set_workstation(template.get_workstation())
135         c.set_gensec_features(c.get_gensec_features()
136                               | gensec.FEATURE_SEAL)
137         c.set_kerberos_state(kerberos_state)
138         return c
139
140     # These functions didn't exist before Python2.7:
141     if sys.version_info < (2, 7):
142         import warnings
143
144         def skipTest(self, reason):
145             raise SkipTest(reason)
146
147         def assertIn(self, member, container, msg=None):
148             self.assertTrue(member in container, msg)
149
150         def assertIs(self, a, b, msg=None):
151             self.assertTrue(a is b, msg)
152
153         def assertIsNot(self, a, b, msg=None):
154             self.assertTrue(a is not b, msg)
155
156         def assertIsNotNone(self, a, msg=None):
157             self.assertTrue(a is not None)
158
159         def assertIsInstance(self, a, b, msg=None):
160             self.assertTrue(isinstance(a, b), msg)
161
162         def assertIsNone(self, a, msg=None):
163             self.assertTrue(a is None, msg)
164
165         def assertGreater(self, a, b, msg=None):
166             self.assertTrue(a > b, msg)
167
168         def assertGreaterEqual(self, a, b, msg=None):
169             self.assertTrue(a >= b, msg)
170
171         def assertLess(self, a, b, msg=None):
172             self.assertTrue(a < b, msg)
173
174         def assertLessEqual(self, a, b, msg=None):
175             self.assertTrue(a <= b, msg)
176
177         def addCleanup(self, fn, *args, **kwargs):
178             self._cleanups = getattr(self, "_cleanups", []) + [
179                 (fn, args, kwargs)]
180
181         def assertRegexpMatches(self, text, regex, msg=None):
182             # PY3 note: Python 3 will never see this, but we use
183             # text_type for the benefit of linters.
184             if isinstance(regex, (str, text_type)):
185                 regex = re.compile(regex)
186             if not regex.search(text):
187                 self.fail(msg)
188
189         def _addSkip(self, result, reason):
190             addSkip = getattr(result, 'addSkip', None)
191             if addSkip is not None:
192                 addSkip(self, reason)
193             else:
194                 warnings.warn("TestResult has no addSkip method, skips not reported",
195                               RuntimeWarning, 2)
196                 result.addSuccess(self)
197
198         def run(self, result=None):
199             if result is None:
200                 result = self.defaultTestResult()
201             result.startTest(self)
202             testMethod = getattr(self, self._testMethodName)
203             try:
204                 try:
205                     self.setUp()
206                 except SkipTest as e:
207                     self._addSkip(result, str(e))
208                     return
209                 except KeyboardInterrupt:
210                     raise
211                 except:
212                     result.addError(self, self._exc_info())
213                     return
214
215                 ok = False
216                 try:
217                     testMethod()
218                     ok = True
219                 except SkipTest as e:
220                     self._addSkip(result, str(e))
221                     return
222                 except self.failureException:
223                     result.addFailure(self, self._exc_info())
224                 except KeyboardInterrupt:
225                     raise
226                 except:
227                     result.addError(self, self._exc_info())
228
229                 try:
230                     self.tearDown()
231                 except SkipTest as e:
232                     self._addSkip(result, str(e))
233                 except KeyboardInterrupt:
234                     raise
235                 except:
236                     result.addError(self, self._exc_info())
237                     ok = False
238
239                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
240                     fn(*args, **kwargs)
241                 if ok:
242                     result.addSuccess(self)
243             finally:
244                 result.stopTest(self)
245
246     def assertStringsEqual(self, a, b, msg=None, strip=False):
247         """Assert equality between two strings and highlight any differences.
248         If strip is true, leading and trailing whitespace is ignored."""
249         if strip:
250             a = a.strip()
251             b = b.strip()
252
253         if a != b:
254             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
255                              "a diff follows\n"
256                              % ('when stripped ' if strip else '',
257                                 len(a), len(b),
258                                 ))
259
260             from difflib import unified_diff
261             diff = unified_diff(a.splitlines(True),
262                                 b.splitlines(True),
263                                 'a', 'b')
264             for line in diff:
265                 sys.stderr.write(line)
266
267             self.fail(msg)
268
269
270 class LdbTestCase(TestCase):
271     """Trivial test case for running tests against a LDB."""
272
273     def setUp(self):
274         super(LdbTestCase, self).setUp()
275         self.tempfile = tempfile.NamedTemporaryFile(delete=False)
276         self.filename = self.tempfile.name
277         self.ldb = samba.Ldb(self.filename)
278
279     def set_modules(self, modules=[]):
280         """Change the modules for this Ldb."""
281         m = ldb.Message()
282         m.dn = ldb.Dn(self.ldb, "@MODULES")
283         m["@LIST"] = ",".join(modules)
284         self.ldb.add(m)
285         self.ldb = samba.Ldb(self.filename)
286
287
288 class TestCaseInTempDir(TestCase):
289
290     def setUp(self):
291         super(TestCaseInTempDir, self).setUp()
292         self.tempdir = tempfile.mkdtemp()
293         self.addCleanup(self._remove_tempdir)
294
295     def _remove_tempdir(self):
296         self.assertEquals([], os.listdir(self.tempdir))
297         os.rmdir(self.tempdir)
298         self.tempdir = None
299
300
301 def env_loadparm():
302     lp = param.LoadParm()
303     try:
304         lp.load(os.environ["SMB_CONF_PATH"])
305     except KeyError:
306         raise KeyError("SMB_CONF_PATH not set")
307     return lp
308
309
310 def env_get_var_value(var_name, allow_missing=False):
311     """Returns value for variable in os.environ
312
313     Function throws AssertionError if variable is defined.
314     Unit-test based python tests require certain input params
315     to be set in environment, otherwise they can't be run
316     """
317     if allow_missing:
318         if var_name not in os.environ.keys():
319             return None
320     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
321     return os.environ[var_name]
322
323
324 cmdline_credentials = None
325
326
327 class RpcInterfaceTestCase(TestCase):
328     """DCE/RPC Test case."""
329
330
331 class ValidNetbiosNameTests(TestCase):
332
333     def test_valid(self):
334         self.assertTrue(samba.valid_netbios_name("FOO"))
335
336     def test_too_long(self):
337         self.assertFalse(samba.valid_netbios_name("FOO" * 10))
338
339     def test_invalid_characters(self):
340         self.assertFalse(samba.valid_netbios_name("*BLA"))
341
342
343 class BlackboxProcessError(Exception):
344     """This is raised when check_output() process returns a non-zero exit status
345
346     Exception instance should contain the exact exit code (S.returncode),
347     command line (S.cmd), process output (S.stdout) and process error stream
348     (S.stderr)
349     """
350
351     def __init__(self, returncode, cmd, stdout, stderr, msg=None):
352         self.returncode = returncode
353         self.cmd = cmd
354         self.stdout = stdout
355         self.stderr = stderr
356         self.msg = msg
357
358     def __str__(self):
359         s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
360              (self.cmd, self.returncode, self.stdout, self.stderr))
361         if self.msg is not None:
362             s = "%s; message: %s" % (s, self.msg)
363
364         return s
365
366
367 class BlackboxTestCase(TestCaseInTempDir):
368     """Base test case for blackbox tests."""
369
370     def _make_cmdline(self, line):
371         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
372         parts = line.split(" ")
373         if os.path.exists(os.path.join(bindir, parts[0])):
374             cmd = parts[0]
375             parts[0] = os.path.join(bindir, parts[0])
376             if cmd == "samba-tool" and os.getenv("PYTHON", None):
377                 parts = [os.environ["PYTHON"]] + parts
378         line = " ".join(parts)
379         return line
380
381     def check_run(self, line, msg=None):
382         self.check_exit_code(line, 0, msg=msg)
383
384     def check_exit_code(self, line, expected, msg=None):
385         line = self._make_cmdline(line)
386         p = subprocess.Popen(line,
387                              stdout=subprocess.PIPE,
388                              stderr=subprocess.PIPE,
389                              shell=True)
390         stdoutdata, stderrdata = p.communicate()
391         retcode = p.returncode
392         if retcode != expected:
393             raise BlackboxProcessError(retcode,
394                                        line,
395                                        stdoutdata,
396                                        stderrdata,
397                                        msg)
398
399     def check_output(self, line):
400         line = self._make_cmdline(line)
401         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
402         stdoutdata, stderrdata = p.communicate()
403         retcode = p.returncode
404         if retcode:
405             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
406         return stdoutdata
407
408     # Generate a random password that can be safely  passed on the command line
409     # i.e. it does not contain any shell meta characters.
410     def random_password(self, count=32):
411         password = SystemRandom().choice(string.ascii_uppercase)
412         password += SystemRandom().choice(string.digits)
413         password += SystemRandom().choice(string.ascii_lowercase)
414         password += ''.join(SystemRandom().choice(string.ascii_uppercase +
415                     string.ascii_lowercase +
416                     string.digits) for x in range(count - 3))
417         return password
418
419
420 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
421                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
422     """Create SamDB instance and connects to samdb_url database.
423
424     :param samdb_url: Url for database to connect to.
425     :param lp: Optional loadparm object
426     :param session_info: Optional session information
427     :param credentials: Optional credentials, defaults to anonymous.
428     :param flags: Optional LDB flags
429     :param ldap_only: If set, only remote LDAP connection will be created.
430     :param global_schema: Whether to use global schema.
431
432     Added value for tests is that we have a shorthand function
433     to make proper URL for ldb.connect() while using default
434     parameters for connection based on test environment
435     """
436     if "://" not in samdb_url:
437         if not ldap_only and os.path.isfile(samdb_url):
438             samdb_url = "tdb://%s" % samdb_url
439         else:
440             samdb_url = "ldap://%s" % samdb_url
441     # use 'paged_search' module when connecting remotely
442     if samdb_url.startswith("ldap://"):
443         ldb_options = ["modules:paged_searches"]
444     elif ldap_only:
445         raise AssertionError("Trying to connect to %s while remote "
446                              "connection is required" % samdb_url)
447
448     # set defaults for test environment
449     if lp is None:
450         lp = env_loadparm()
451     if session_info is None:
452         session_info = samba.auth.system_session(lp)
453     if credentials is None:
454         credentials = cmdline_credentials
455
456     return SamDB(url=samdb_url,
457                  lp=lp,
458                  session_info=session_info,
459                  credentials=credentials,
460                  flags=flags,
461                  options=ldb_options,
462                  global_schema=global_schema)
463
464
465 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
466                      flags=0, ldb_options=None, ldap_only=False):
467     """Connects to samdb_url database
468
469     :param samdb_url: Url for database to connect to.
470     :param lp: Optional loadparm object
471     :param session_info: Optional session information
472     :param credentials: Optional credentials, defaults to anonymous.
473     :param flags: Optional LDB flags
474     :param ldap_only: If set, only remote LDAP connection will be created.
475     :return: (sam_db_connection, rootDse_record) tuple
476     """
477     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
478                            flags, ldb_options, ldap_only)
479     # fetch RootDse
480     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
481                         attrs=["*"])
482     return (sam_db, res[0])
483
484
485 def connect_samdb_env(env_url, env_username, env_password, lp=None):
486     """Connect to SamDB by getting URL and Credentials from environment
487
488     :param env_url: Environment variable name to get lsb url from
489     :param env_username: Username environment variable
490     :param env_password: Password environment variable
491     :return: sam_db_connection
492     """
493     samdb_url = env_get_var_value(env_url)
494     creds = credentials.Credentials()
495     if lp is None:
496         # guess Credentials parameters here. Otherwise workstation
497         # and domain fields are NULL and gencache code segfalts
498         lp = param.LoadParm()
499         creds.guess(lp)
500     creds.set_username(env_get_var_value(env_username))
501     creds.set_password(env_get_var_value(env_password))
502     return connect_samdb(samdb_url, credentials=creds, lp=lp)
503
504
505 def delete_force(samdb, dn, **kwargs):
506     try:
507         samdb.delete(dn, **kwargs)
508     except ldb.LdbError as error:
509         (num, errstr) = error.args
510         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
511
512
513 def create_test_ou(samdb, name):
514     """Creates a unique OU for the test"""
515
516     # Add some randomness to the test OU. Replication between the testenvs is
517     # constantly happening in the background. Deletion of the last test's
518     # objects can be slow to replicate out. So the OU created by a previous
519     # testenv may still exist at the point that tests start on another testenv.
520     rand = randint(1, 10000000)
521     dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
522     samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
523     return dn