b6293dff75707c1f7e98cf4f42d572771b515496
[samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import ldb
23 import samba
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 from samba import gensec
28 import socket
29 import struct
30 import subprocess
31 import sys
32 import tempfile
33 import unittest
34 import re
35 import samba.auth
36 import samba.dcerpc.base
37 from samba.compat import PY3, text_type
38 if not PY3:
39     # Py2 only
40     from samba.samdb import SamDB
41     import samba.ndr
42     import samba.dcerpc.dcerpc
43     import samba.dcerpc.epmapper
44
45 try:
46     from unittest import SkipTest
47 except ImportError:
48     class SkipTest(Exception):
49         """Test skipped."""
50
51 HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
52
53 class TestCase(unittest.TestCase):
54     """A Samba test case."""
55
56     def setUp(self):
57         super(TestCase, self).setUp()
58         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
59         if test_debug_level is not None:
60             test_debug_level = int(test_debug_level)
61             self._old_debug_level = samba.get_debug_level()
62             samba.set_debug_level(test_debug_level)
63             self.addCleanup(samba.set_debug_level, test_debug_level)
64
65     def get_loadparm(self):
66         return env_loadparm()
67
68     def get_credentials(self):
69         return cmdline_credentials
70
71     def get_creds_ccache_name(self):
72         creds = self.get_credentials()
73         ccache = creds.get_named_ccache(self.get_loadparm())
74         ccache_name = ccache.get_name()
75
76         return ccache_name
77
78     def hexdump(self, src):
79         N = 0
80         result = ''
81         while src:
82             ll = src[:8]
83             lr = src[8:16]
84             src = src[16:]
85             hl = ' '.join(["%02X" % ord(x) for x in ll])
86             hr = ' '.join(["%02X" % ord(x) for x in lr])
87             ll = ll.translate(HEXDUMP_FILTER)
88             lr = lr.translate(HEXDUMP_FILTER)
89             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
90             N += 16
91         return result
92
93     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
94
95         if template is None:
96             assert template is not None
97
98         if username is not None:
99             assert userpass is not None
100
101         if username is None:
102             assert userpass is None
103
104             username = template.get_username()
105             userpass = template.get_password()
106
107         if kerberos_state is None:
108             kerberos_state = template.get_kerberos_state()
109
110         # get a copy of the global creds or a the passed in creds
111         c = Credentials()
112         c.set_username(username)
113         c.set_password(userpass)
114         c.set_domain(template.get_domain())
115         c.set_realm(template.get_realm())
116         c.set_workstation(template.get_workstation())
117         c.set_gensec_features(c.get_gensec_features()
118                               | gensec.FEATURE_SEAL)
119         c.set_kerberos_state(kerberos_state)
120         return c
121
122
123
124     # These functions didn't exist before Python2.7:
125     if sys.version_info < (2, 7):
126         import warnings
127
128         def skipTest(self, reason):
129             raise SkipTest(reason)
130
131         def assertIn(self, member, container, msg=None):
132             self.assertTrue(member in container, msg)
133
134         def assertIs(self, a, b, msg=None):
135             self.assertTrue(a is b, msg)
136
137         def assertIsNot(self, a, b, msg=None):
138             self.assertTrue(a is not b, msg)
139
140         def assertIsNotNone(self, a, msg=None):
141             self.assertTrue(a is not None)
142
143         def assertIsInstance(self, a, b, msg=None):
144             self.assertTrue(isinstance(a, b), msg)
145
146         def assertIsNone(self, a, msg=None):
147             self.assertTrue(a is None, msg)
148
149         def assertGreater(self, a, b, msg=None):
150             self.assertTrue(a > b, msg)
151
152         def assertGreaterEqual(self, a, b, msg=None):
153             self.assertTrue(a >= b, msg)
154
155         def assertLess(self, a, b, msg=None):
156             self.assertTrue(a < b, msg)
157
158         def assertLessEqual(self, a, b, msg=None):
159             self.assertTrue(a <= b, msg)
160
161         def addCleanup(self, fn, *args, **kwargs):
162             self._cleanups = getattr(self, "_cleanups", []) + [
163                 (fn, args, kwargs)]
164
165         def assertRegexpMatches(self, text, regex, msg=None):
166             # PY3 note: Python 3 will never see this, but we use
167             # text_type for the benefit of linters.
168             if isinstance(regex, (str, text_type)):
169                 regex = re.compile(regex)
170             if not regex.search(text):
171                 self.fail(msg)
172
173         def _addSkip(self, result, reason):
174             addSkip = getattr(result, 'addSkip', None)
175             if addSkip is not None:
176                 addSkip(self, reason)
177             else:
178                 warnings.warn("TestResult has no addSkip method, skips not reported",
179                               RuntimeWarning, 2)
180                 result.addSuccess(self)
181
182         def run(self, result=None):
183             if result is None: result = self.defaultTestResult()
184             result.startTest(self)
185             testMethod = getattr(self, self._testMethodName)
186             try:
187                 try:
188                     self.setUp()
189                 except SkipTest as e:
190                     self._addSkip(result, str(e))
191                     return
192                 except KeyboardInterrupt:
193                     raise
194                 except:
195                     result.addError(self, self._exc_info())
196                     return
197
198                 ok = False
199                 try:
200                     testMethod()
201                     ok = True
202                 except SkipTest as e:
203                     self._addSkip(result, str(e))
204                     return
205                 except self.failureException:
206                     result.addFailure(self, self._exc_info())
207                 except KeyboardInterrupt:
208                     raise
209                 except:
210                     result.addError(self, self._exc_info())
211
212                 try:
213                     self.tearDown()
214                 except SkipTest as e:
215                     self._addSkip(result, str(e))
216                 except KeyboardInterrupt:
217                     raise
218                 except:
219                     result.addError(self, self._exc_info())
220                     ok = False
221
222                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
223                     fn(*args, **kwargs)
224                 if ok: result.addSuccess(self)
225             finally:
226                 result.stopTest(self)
227
228     def assertStringsEqual(self, a, b, msg=None, strip=False):
229         """Assert equality between two strings and highlight any differences.
230         If strip is true, leading and trailing whitespace is ignored."""
231         if strip:
232             a = a.strip()
233             b = b.strip()
234
235         if a != b:
236             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
237                              "a diff follows\n"
238                              % ('when stripped ' if strip else '',
239                                 len(a), len(b),
240                              ))
241
242             from difflib import unified_diff
243             diff = unified_diff(a.splitlines(True),
244                                 b.splitlines(True),
245                                 'a', 'b')
246             for line in diff:
247                 sys.stderr.write(line)
248
249             self.fail(msg)
250
251
252 class LdbTestCase(TestCase):
253     """Trivial test case for running tests against a LDB."""
254
255     def setUp(self):
256         super(LdbTestCase, self).setUp()
257         self.filename = os.tempnam()
258         self.ldb = samba.Ldb(self.filename)
259
260     def set_modules(self, modules=[]):
261         """Change the modules for this Ldb."""
262         m = ldb.Message()
263         m.dn = ldb.Dn(self.ldb, "@MODULES")
264         m["@LIST"] = ",".join(modules)
265         self.ldb.add(m)
266         self.ldb = samba.Ldb(self.filename)
267
268
269 class TestCaseInTempDir(TestCase):
270
271     def setUp(self):
272         super(TestCaseInTempDir, self).setUp()
273         self.tempdir = tempfile.mkdtemp()
274         self.addCleanup(self._remove_tempdir)
275
276     def _remove_tempdir(self):
277         self.assertEquals([], os.listdir(self.tempdir))
278         os.rmdir(self.tempdir)
279         self.tempdir = None
280
281
282 def env_loadparm():
283     lp = param.LoadParm()
284     try:
285         lp.load(os.environ["SMB_CONF_PATH"])
286     except KeyError:
287         raise KeyError("SMB_CONF_PATH not set")
288     return lp
289
290
291 def env_get_var_value(var_name, allow_missing=False):
292     """Returns value for variable in os.environ
293
294     Function throws AssertionError if variable is defined.
295     Unit-test based python tests require certain input params
296     to be set in environment, otherwise they can't be run
297     """
298     if allow_missing:
299         if var_name not in os.environ.keys():
300             return None
301     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
302     return os.environ[var_name]
303
304
305 cmdline_credentials = None
306
307 class RpcInterfaceTestCase(TestCase):
308     """DCE/RPC Test case."""
309
310
311 class ValidNetbiosNameTests(TestCase):
312
313     def test_valid(self):
314         self.assertTrue(samba.valid_netbios_name("FOO"))
315
316     def test_too_long(self):
317         self.assertFalse(samba.valid_netbios_name("FOO"*10))
318
319     def test_invalid_characters(self):
320         self.assertFalse(samba.valid_netbios_name("*BLA"))
321
322
323 class BlackboxProcessError(Exception):
324     """This is raised when check_output() process returns a non-zero exit status
325
326     Exception instance should contain the exact exit code (S.returncode),
327     command line (S.cmd), process output (S.stdout) and process error stream
328     (S.stderr)
329     """
330
331     def __init__(self, returncode, cmd, stdout, stderr):
332         self.returncode = returncode
333         self.cmd = cmd
334         self.stdout = stdout
335         self.stderr = stderr
336
337     def __str__(self):
338         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
339                                                                              self.stdout, self.stderr)
340
341 class BlackboxTestCase(TestCaseInTempDir):
342     """Base test case for blackbox tests."""
343
344     def _make_cmdline(self, line):
345         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
346         parts = line.split(" ")
347         if os.path.exists(os.path.join(bindir, parts[0])):
348             parts[0] = os.path.join(bindir, parts[0])
349         line = " ".join(parts)
350         return line
351
352     def check_run(self, line):
353         self.check_exit_code(line, 0)
354
355     def check_exit_code(self, line, expected):
356         line = self._make_cmdline(line)
357         p = subprocess.Popen(line,
358                              stdout=subprocess.PIPE,
359                              stderr=subprocess.PIPE,
360                              shell=True)
361         stdoutdata, stderrdata = p.communicate()
362         retcode = p.returncode
363         if retcode != expected:
364             raise BlackboxProcessError(retcode,
365                                        line,
366                                        stdoutdata,
367                                        stderrdata)
368
369     def check_output(self, line):
370         line = self._make_cmdline(line)
371         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
372         stdoutdata, stderrdata = p.communicate()
373         retcode = p.returncode
374         if retcode:
375             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
376         return stdoutdata
377
378
379 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
380                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
381     """Create SamDB instance and connects to samdb_url database.
382
383     :param samdb_url: Url for database to connect to.
384     :param lp: Optional loadparm object
385     :param session_info: Optional session information
386     :param credentials: Optional credentials, defaults to anonymous.
387     :param flags: Optional LDB flags
388     :param ldap_only: If set, only remote LDAP connection will be created.
389     :param global_schema: Whether to use global schema.
390
391     Added value for tests is that we have a shorthand function
392     to make proper URL for ldb.connect() while using default
393     parameters for connection based on test environment
394     """
395     if not "://" in samdb_url:
396         if not ldap_only and os.path.isfile(samdb_url):
397             samdb_url = "tdb://%s" % samdb_url
398         else:
399             samdb_url = "ldap://%s" % samdb_url
400     # use 'paged_search' module when connecting remotely
401     if samdb_url.startswith("ldap://"):
402         ldb_options = ["modules:paged_searches"]
403     elif ldap_only:
404         raise AssertionError("Trying to connect to %s while remote "
405                              "connection is required" % samdb_url)
406
407     # set defaults for test environment
408     if lp is None:
409         lp = env_loadparm()
410     if session_info is None:
411         session_info = samba.auth.system_session(lp)
412     if credentials is None:
413         credentials = cmdline_credentials
414
415     return SamDB(url=samdb_url,
416                  lp=lp,
417                  session_info=session_info,
418                  credentials=credentials,
419                  flags=flags,
420                  options=ldb_options,
421                  global_schema=global_schema)
422
423
424 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
425                      flags=0, ldb_options=None, ldap_only=False):
426     """Connects to samdb_url database
427
428     :param samdb_url: Url for database to connect to.
429     :param lp: Optional loadparm object
430     :param session_info: Optional session information
431     :param credentials: Optional credentials, defaults to anonymous.
432     :param flags: Optional LDB flags
433     :param ldap_only: If set, only remote LDAP connection will be created.
434     :return: (sam_db_connection, rootDse_record) tuple
435     """
436     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
437                            flags, ldb_options, ldap_only)
438     # fetch RootDse
439     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
440                         attrs=["*"])
441     return (sam_db, res[0])
442
443
444 def connect_samdb_env(env_url, env_username, env_password, lp=None):
445     """Connect to SamDB by getting URL and Credentials from environment
446
447     :param env_url: Environment variable name to get lsb url from
448     :param env_username: Username environment variable
449     :param env_password: Password environment variable
450     :return: sam_db_connection
451     """
452     samdb_url = env_get_var_value(env_url)
453     creds = credentials.Credentials()
454     if lp is None:
455         # guess Credentials parameters here. Otherwise workstation
456         # and domain fields are NULL and gencache code segfalts
457         lp = param.LoadParm()
458         creds.guess(lp)
459     creds.set_username(env_get_var_value(env_username))
460     creds.set_password(env_get_var_value(env_password))
461     return connect_samdb(samdb_url, credentials=creds, lp=lp)
462
463
464 def delete_force(samdb, dn, **kwargs):
465     try:
466         samdb.delete(dn, **kwargs)
467     except ldb.LdbError as error:
468         (num, errstr) = error.args
469         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr