b3fccee838348ba74e6504b1b6bedd7cff6d22a4
[nivanova/samba-autobuild/.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import tempfile
23 import ldb
24 import samba
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
29 import socket
30 import struct
31 import subprocess
32 import sys
33 import tempfile
34 import unittest
35 import re
36 import samba.auth
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
41 try:
42     from samba.samdb import SamDB
43 except ImportError:
44     # We are built without samdb support,
45     # imitate it so that connect_samdb() can recover
46     def SamDB(*args, **kwargs):
47         return None
48
49 import samba.ndr
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
52
53 try:
54     from unittest import SkipTest
55 except ImportError:
56     class SkipTest(Exception):
57         """Test skipped."""
58
59 HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
60
61
62 class TestCase(unittest.TestCase):
63     """A Samba test case."""
64
65     def setUp(self):
66         super(TestCase, self).setUp()
67         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
68         if test_debug_level is not None:
69             test_debug_level = int(test_debug_level)
70             self._old_debug_level = samba.get_debug_level()
71             samba.set_debug_level(test_debug_level)
72             self.addCleanup(samba.set_debug_level, test_debug_level)
73
74     def get_loadparm(self):
75         return env_loadparm()
76
77     def get_credentials(self):
78         return cmdline_credentials
79
80     def get_creds_ccache_name(self):
81         creds = self.get_credentials()
82         ccache = creds.get_named_ccache(self.get_loadparm())
83         ccache_name = ccache.get_name()
84
85         return ccache_name
86
87     def hexdump(self, src):
88         N = 0
89         result = ''
90         is_string = isinstance(src, string_types)
91         while src:
92             ll = src[:8]
93             lr = src[8:16]
94             src = src[16:]
95             if is_string:
96                 hl = ' '.join(["%02X" % ord(x) for x in ll])
97                 hr = ' '.join(["%02X" % ord(x) for x in lr])
98                 ll = ll.translate(HEXDUMP_FILTER)
99                 lr = lr.translate(HEXDUMP_FILTER)
100             else:
101                 hl = ' '.join(["%02X" % x for x in ll])
102                 hr = ' '.join(["%02X" % x for x in lr])
103                 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
104                 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
105             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
106             N += 16
107         return result
108
109     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
110
111         if template is None:
112             assert template is not None
113
114         if username is not None:
115             assert userpass is not None
116
117         if username is None:
118             assert userpass is None
119
120             username = template.get_username()
121             userpass = template.get_password()
122
123         if kerberos_state is None:
124             kerberos_state = template.get_kerberos_state()
125
126         # get a copy of the global creds or a the passed in creds
127         c = Credentials()
128         c.set_username(username)
129         c.set_password(userpass)
130         c.set_domain(template.get_domain())
131         c.set_realm(template.get_realm())
132         c.set_workstation(template.get_workstation())
133         c.set_gensec_features(c.get_gensec_features()
134                               | gensec.FEATURE_SEAL)
135         c.set_kerberos_state(kerberos_state)
136         return c
137
138     # These functions didn't exist before Python2.7:
139     if sys.version_info < (2, 7):
140         import warnings
141
142         def skipTest(self, reason):
143             raise SkipTest(reason)
144
145         def assertIn(self, member, container, msg=None):
146             self.assertTrue(member in container, msg)
147
148         def assertIs(self, a, b, msg=None):
149             self.assertTrue(a is b, msg)
150
151         def assertIsNot(self, a, b, msg=None):
152             self.assertTrue(a is not b, msg)
153
154         def assertIsNotNone(self, a, msg=None):
155             self.assertTrue(a is not None)
156
157         def assertIsInstance(self, a, b, msg=None):
158             self.assertTrue(isinstance(a, b), msg)
159
160         def assertIsNone(self, a, msg=None):
161             self.assertTrue(a is None, msg)
162
163         def assertGreater(self, a, b, msg=None):
164             self.assertTrue(a > b, msg)
165
166         def assertGreaterEqual(self, a, b, msg=None):
167             self.assertTrue(a >= b, msg)
168
169         def assertLess(self, a, b, msg=None):
170             self.assertTrue(a < b, msg)
171
172         def assertLessEqual(self, a, b, msg=None):
173             self.assertTrue(a <= b, msg)
174
175         def addCleanup(self, fn, *args, **kwargs):
176             self._cleanups = getattr(self, "_cleanups", []) + [
177                 (fn, args, kwargs)]
178
179         def assertRegexpMatches(self, text, regex, msg=None):
180             # PY3 note: Python 3 will never see this, but we use
181             # text_type for the benefit of linters.
182             if isinstance(regex, (str, text_type)):
183                 regex = re.compile(regex)
184             if not regex.search(text):
185                 self.fail(msg)
186
187         def _addSkip(self, result, reason):
188             addSkip = getattr(result, 'addSkip', None)
189             if addSkip is not None:
190                 addSkip(self, reason)
191             else:
192                 warnings.warn("TestResult has no addSkip method, skips not reported",
193                               RuntimeWarning, 2)
194                 result.addSuccess(self)
195
196         def run(self, result=None):
197             if result is None:
198                 result = self.defaultTestResult()
199             result.startTest(self)
200             testMethod = getattr(self, self._testMethodName)
201             try:
202                 try:
203                     self.setUp()
204                 except SkipTest as e:
205                     self._addSkip(result, str(e))
206                     return
207                 except KeyboardInterrupt:
208                     raise
209                 except:
210                     result.addError(self, self._exc_info())
211                     return
212
213                 ok = False
214                 try:
215                     testMethod()
216                     ok = True
217                 except SkipTest as e:
218                     self._addSkip(result, str(e))
219                     return
220                 except self.failureException:
221                     result.addFailure(self, self._exc_info())
222                 except KeyboardInterrupt:
223                     raise
224                 except:
225                     result.addError(self, self._exc_info())
226
227                 try:
228                     self.tearDown()
229                 except SkipTest as e:
230                     self._addSkip(result, str(e))
231                 except KeyboardInterrupt:
232                     raise
233                 except:
234                     result.addError(self, self._exc_info())
235                     ok = False
236
237                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
238                     fn(*args, **kwargs)
239                 if ok:
240                     result.addSuccess(self)
241             finally:
242                 result.stopTest(self)
243
244     def assertStringsEqual(self, a, b, msg=None, strip=False):
245         """Assert equality between two strings and highlight any differences.
246         If strip is true, leading and trailing whitespace is ignored."""
247         if strip:
248             a = a.strip()
249             b = b.strip()
250
251         if a != b:
252             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
253                              "a diff follows\n"
254                              % ('when stripped ' if strip else '',
255                                 len(a), len(b),
256                                 ))
257
258             from difflib import unified_diff
259             diff = unified_diff(a.splitlines(True),
260                                 b.splitlines(True),
261                                 'a', 'b')
262             for line in diff:
263                 sys.stderr.write(line)
264
265             self.fail(msg)
266
267
268 class LdbTestCase(TestCase):
269     """Trivial test case for running tests against a LDB."""
270
271     def setUp(self):
272         super(LdbTestCase, self).setUp()
273         self.tempfile = tempfile.NamedTemporaryFile(delete=False)
274         self.filename = self.tempfile.name
275         self.ldb = samba.Ldb(self.filename)
276
277     def set_modules(self, modules=[]):
278         """Change the modules for this Ldb."""
279         m = ldb.Message()
280         m.dn = ldb.Dn(self.ldb, "@MODULES")
281         m["@LIST"] = ",".join(modules)
282         self.ldb.add(m)
283         self.ldb = samba.Ldb(self.filename)
284
285
286 class TestCaseInTempDir(TestCase):
287
288     def setUp(self):
289         super(TestCaseInTempDir, self).setUp()
290         self.tempdir = tempfile.mkdtemp()
291         self.addCleanup(self._remove_tempdir)
292
293     def _remove_tempdir(self):
294         self.assertEquals([], os.listdir(self.tempdir))
295         os.rmdir(self.tempdir)
296         self.tempdir = None
297
298
299 def env_loadparm():
300     lp = param.LoadParm()
301     try:
302         lp.load(os.environ["SMB_CONF_PATH"])
303     except KeyError:
304         raise KeyError("SMB_CONF_PATH not set")
305     return lp
306
307
308 def env_get_var_value(var_name, allow_missing=False):
309     """Returns value for variable in os.environ
310
311     Function throws AssertionError if variable is defined.
312     Unit-test based python tests require certain input params
313     to be set in environment, otherwise they can't be run
314     """
315     if allow_missing:
316         if var_name not in os.environ.keys():
317             return None
318     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
319     return os.environ[var_name]
320
321
322 cmdline_credentials = None
323
324
325 class RpcInterfaceTestCase(TestCase):
326     """DCE/RPC Test case."""
327
328
329 class ValidNetbiosNameTests(TestCase):
330
331     def test_valid(self):
332         self.assertTrue(samba.valid_netbios_name("FOO"))
333
334     def test_too_long(self):
335         self.assertFalse(samba.valid_netbios_name("FOO" * 10))
336
337     def test_invalid_characters(self):
338         self.assertFalse(samba.valid_netbios_name("*BLA"))
339
340
341 class BlackboxProcessError(Exception):
342     """This is raised when check_output() process returns a non-zero exit status
343
344     Exception instance should contain the exact exit code (S.returncode),
345     command line (S.cmd), process output (S.stdout) and process error stream
346     (S.stderr)
347     """
348
349     def __init__(self, returncode, cmd, stdout, stderr, msg=None):
350         self.returncode = returncode
351         self.cmd = cmd
352         self.stdout = stdout
353         self.stderr = stderr
354         self.msg = msg
355
356     def __str__(self):
357         s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
358              (self.cmd, self.returncode, self.stdout, self.stderr))
359         if self.msg is not None:
360             s = "%s; message: %s" % (s, self.msg)
361
362         return s
363
364
365 class BlackboxTestCase(TestCaseInTempDir):
366     """Base test case for blackbox tests."""
367
368     def _make_cmdline(self, line):
369         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
370         parts = line.split(" ")
371         if os.path.exists(os.path.join(bindir, parts[0])):
372             parts[0] = os.path.join(bindir, parts[0])
373         line = " ".join(parts)
374         return line
375
376     def check_run(self, line, msg=None):
377         self.check_exit_code(line, 0, msg=msg)
378
379     def check_exit_code(self, line, expected, msg=None):
380         line = self._make_cmdline(line)
381         p = subprocess.Popen(line,
382                              stdout=subprocess.PIPE,
383                              stderr=subprocess.PIPE,
384                              shell=True)
385         stdoutdata, stderrdata = p.communicate()
386         retcode = p.returncode
387         if retcode != expected:
388             raise BlackboxProcessError(retcode,
389                                        line,
390                                        stdoutdata,
391                                        stderrdata,
392                                        msg)
393
394     def check_output(self, line):
395         line = self._make_cmdline(line)
396         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
397         stdoutdata, stderrdata = p.communicate()
398         retcode = p.returncode
399         if retcode:
400             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
401         return stdoutdata
402
403
404 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
405                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
406     """Create SamDB instance and connects to samdb_url database.
407
408     :param samdb_url: Url for database to connect to.
409     :param lp: Optional loadparm object
410     :param session_info: Optional session information
411     :param credentials: Optional credentials, defaults to anonymous.
412     :param flags: Optional LDB flags
413     :param ldap_only: If set, only remote LDAP connection will be created.
414     :param global_schema: Whether to use global schema.
415
416     Added value for tests is that we have a shorthand function
417     to make proper URL for ldb.connect() while using default
418     parameters for connection based on test environment
419     """
420     if not "://" in samdb_url:
421         if not ldap_only and os.path.isfile(samdb_url):
422             samdb_url = "tdb://%s" % samdb_url
423         else:
424             samdb_url = "ldap://%s" % samdb_url
425     # use 'paged_search' module when connecting remotely
426     if samdb_url.startswith("ldap://"):
427         ldb_options = ["modules:paged_searches"]
428     elif ldap_only:
429         raise AssertionError("Trying to connect to %s while remote "
430                              "connection is required" % samdb_url)
431
432     # set defaults for test environment
433     if lp is None:
434         lp = env_loadparm()
435     if session_info is None:
436         session_info = samba.auth.system_session(lp)
437     if credentials is None:
438         credentials = cmdline_credentials
439
440     return SamDB(url=samdb_url,
441                  lp=lp,
442                  session_info=session_info,
443                  credentials=credentials,
444                  flags=flags,
445                  options=ldb_options,
446                  global_schema=global_schema)
447
448
449 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
450                      flags=0, ldb_options=None, ldap_only=False):
451     """Connects to samdb_url database
452
453     :param samdb_url: Url for database to connect to.
454     :param lp: Optional loadparm object
455     :param session_info: Optional session information
456     :param credentials: Optional credentials, defaults to anonymous.
457     :param flags: Optional LDB flags
458     :param ldap_only: If set, only remote LDAP connection will be created.
459     :return: (sam_db_connection, rootDse_record) tuple
460     """
461     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
462                            flags, ldb_options, ldap_only)
463     # fetch RootDse
464     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
465                         attrs=["*"])
466     return (sam_db, res[0])
467
468
469 def connect_samdb_env(env_url, env_username, env_password, lp=None):
470     """Connect to SamDB by getting URL and Credentials from environment
471
472     :param env_url: Environment variable name to get lsb url from
473     :param env_username: Username environment variable
474     :param env_password: Password environment variable
475     :return: sam_db_connection
476     """
477     samdb_url = env_get_var_value(env_url)
478     creds = credentials.Credentials()
479     if lp is None:
480         # guess Credentials parameters here. Otherwise workstation
481         # and domain fields are NULL and gencache code segfalts
482         lp = param.LoadParm()
483         creds.guess(lp)
484     creds.set_username(env_get_var_value(env_username))
485     creds.set_password(env_get_var_value(env_password))
486     return connect_samdb(samdb_url, credentials=creds, lp=lp)
487
488
489 def delete_force(samdb, dn, **kwargs):
490     try:
491         samdb.delete(dn, **kwargs)
492     except ldb.LdbError as error:
493         (num, errstr) = error.args
494         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
495
496
497 def create_test_ou(samdb, name):
498     """Creates a unique OU for the test"""
499
500     # Add some randomness to the test OU. Replication between the testenvs is
501     # constantly happening in the background. Deletion of the last test's
502     # objects can be slow to replicate out. So the OU created by a previous
503     # testenv may still exist at the point that tests start on another testenv.
504     rand = randint(1, 10000000)
505     dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
506     samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
507     return dn