python tests: assert string equality, with diff
[nivanova/samba-autobuild/.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import ldb
23 import samba
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 from samba import gensec
28 import socket
29 import struct
30 import subprocess
31 import sys
32 import tempfile
33 import unittest
34 import samba.auth
35 import samba.dcerpc.base
36 from samba.compat import PY3
37 if not PY3:
38     # Py2 only
39     from samba.samdb import SamDB
40     import samba.ndr
41     import samba.dcerpc.dcerpc
42     import samba.dcerpc.epmapper
43
44 try:
45     from unittest import SkipTest
46 except ImportError:
47     class SkipTest(Exception):
48         """Test skipped."""
49
50 HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
51
52 class TestCase(unittest.TestCase):
53     """A Samba test case."""
54
55     def setUp(self):
56         super(TestCase, self).setUp()
57         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
58         if test_debug_level is not None:
59             test_debug_level = int(test_debug_level)
60             self._old_debug_level = samba.get_debug_level()
61             samba.set_debug_level(test_debug_level)
62             self.addCleanup(samba.set_debug_level, test_debug_level)
63
64     def get_loadparm(self):
65         return env_loadparm()
66
67     def get_credentials(self):
68         return cmdline_credentials
69
70     def get_creds_ccache_name(self):
71         creds = self.get_credentials()
72         ccache = creds.get_named_ccache(self.get_loadparm())
73         ccache_name = ccache.get_name()
74
75         return ccache_name
76
77     def hexdump(self, src):
78         N = 0
79         result = ''
80         while src:
81             ll = src[:8]
82             lr = src[8:16]
83             src = src[16:]
84             hl = ' '.join(["%02X" % ord(x) for x in ll])
85             hr = ' '.join(["%02X" % ord(x) for x in lr])
86             ll = ll.translate(HEXDUMP_FILTER)
87             lr = lr.translate(HEXDUMP_FILTER)
88             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
89             N += 16
90         return result
91
92     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
93
94         if template is None:
95             assert template is not None
96
97         if username is not None:
98             assert userpass is not None
99
100         if username is None:
101             assert userpass is None
102
103             username = template.get_username()
104             userpass = template.get_password()
105
106         if kerberos_state is None:
107             kerberos_state = template.get_kerberos_state()
108
109         # get a copy of the global creds or a the passed in creds
110         c = Credentials()
111         c.set_username(username)
112         c.set_password(userpass)
113         c.set_domain(template.get_domain())
114         c.set_realm(template.get_realm())
115         c.set_workstation(template.get_workstation())
116         c.set_gensec_features(c.get_gensec_features()
117                               | gensec.FEATURE_SEAL)
118         c.set_kerberos_state(kerberos_state)
119         return c
120
121
122
123     # These functions didn't exist before Python2.7:
124     if sys.version_info < (2, 7):
125         import warnings
126
127         def skipTest(self, reason):
128             raise SkipTest(reason)
129
130         def assertIn(self, member, container, msg=None):
131             self.assertTrue(member in container, msg)
132
133         def assertIs(self, a, b, msg=None):
134             self.assertTrue(a is b, msg)
135
136         def assertIsNot(self, a, b, msg=None):
137             self.assertTrue(a is not b, msg)
138
139         def assertIsNotNone(self, a, msg=None):
140             self.assertTrue(a is not None)
141
142         def assertIsInstance(self, a, b, msg=None):
143             self.assertTrue(isinstance(a, b), msg)
144
145         def assertIsNone(self, a, msg=None):
146             self.assertTrue(a is None, msg)
147
148         def assertGreater(self, a, b, msg=None):
149             self.assertTrue(a > b, msg)
150
151         def assertGreaterEqual(self, a, b, msg=None):
152             self.assertTrue(a >= b, msg)
153
154         def assertLess(self, a, b, msg=None):
155             self.assertTrue(a < b, msg)
156
157         def assertLessEqual(self, a, b, msg=None):
158             self.assertTrue(a <= b, msg)
159
160         def addCleanup(self, fn, *args, **kwargs):
161             self._cleanups = getattr(self, "_cleanups", []) + [
162                 (fn, args, kwargs)]
163
164         def _addSkip(self, result, reason):
165             addSkip = getattr(result, 'addSkip', None)
166             if addSkip is not None:
167                 addSkip(self, reason)
168             else:
169                 warnings.warn("TestResult has no addSkip method, skips not reported",
170                               RuntimeWarning, 2)
171                 result.addSuccess(self)
172
173         def run(self, result=None):
174             if result is None: result = self.defaultTestResult()
175             result.startTest(self)
176             testMethod = getattr(self, self._testMethodName)
177             try:
178                 try:
179                     self.setUp()
180                 except SkipTest as e:
181                     self._addSkip(result, str(e))
182                     return
183                 except KeyboardInterrupt:
184                     raise
185                 except:
186                     result.addError(self, self._exc_info())
187                     return
188
189                 ok = False
190                 try:
191                     testMethod()
192                     ok = True
193                 except SkipTest as e:
194                     self._addSkip(result, str(e))
195                     return
196                 except self.failureException:
197                     result.addFailure(self, self._exc_info())
198                 except KeyboardInterrupt:
199                     raise
200                 except:
201                     result.addError(self, self._exc_info())
202
203                 try:
204                     self.tearDown()
205                 except SkipTest as e:
206                     self._addSkip(result, str(e))
207                 except KeyboardInterrupt:
208                     raise
209                 except:
210                     result.addError(self, self._exc_info())
211                     ok = False
212
213                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
214                     fn(*args, **kwargs)
215                 if ok: result.addSuccess(self)
216             finally:
217                 result.stopTest(self)
218
219     def assertStringsEqual(self, a, b, msg=None, strip=False):
220         """Assert equality between two strings and highlight any differences.
221         If strip is true, leading and trailing whitespace is ignored."""
222         if strip:
223             a = a.strip()
224             b = b.strip()
225
226         if a != b:
227             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
228                              "a diff follows\n"
229                              % ('when stripped ' if strip else '',
230                                 len(a), len(b),
231                              ))
232
233             from difflib import unified_diff
234             diff = unified_diff(a.splitlines(True),
235                                 b.splitlines(True),
236                                 'a', 'b')
237             for line in diff:
238                 sys.stderr.write(line)
239
240             self.fail(msg)
241
242
243 class LdbTestCase(TestCase):
244     """Trivial test case for running tests against a LDB."""
245
246     def setUp(self):
247         super(LdbTestCase, self).setUp()
248         self.filename = os.tempnam()
249         self.ldb = samba.Ldb(self.filename)
250
251     def set_modules(self, modules=[]):
252         """Change the modules for this Ldb."""
253         m = ldb.Message()
254         m.dn = ldb.Dn(self.ldb, "@MODULES")
255         m["@LIST"] = ",".join(modules)
256         self.ldb.add(m)
257         self.ldb = samba.Ldb(self.filename)
258
259
260 class TestCaseInTempDir(TestCase):
261
262     def setUp(self):
263         super(TestCaseInTempDir, self).setUp()
264         self.tempdir = tempfile.mkdtemp()
265         self.addCleanup(self._remove_tempdir)
266
267     def _remove_tempdir(self):
268         self.assertEquals([], os.listdir(self.tempdir))
269         os.rmdir(self.tempdir)
270         self.tempdir = None
271
272
273 def env_loadparm():
274     lp = param.LoadParm()
275     try:
276         lp.load(os.environ["SMB_CONF_PATH"])
277     except KeyError:
278         raise KeyError("SMB_CONF_PATH not set")
279     return lp
280
281
282 def env_get_var_value(var_name, allow_missing=False):
283     """Returns value for variable in os.environ
284
285     Function throws AssertionError if variable is defined.
286     Unit-test based python tests require certain input params
287     to be set in environment, otherwise they can't be run
288     """
289     if allow_missing:
290         if var_name not in os.environ.keys():
291             return None
292     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
293     return os.environ[var_name]
294
295
296 cmdline_credentials = None
297
298 class RpcInterfaceTestCase(TestCase):
299     """DCE/RPC Test case."""
300
301
302 class ValidNetbiosNameTests(TestCase):
303
304     def test_valid(self):
305         self.assertTrue(samba.valid_netbios_name("FOO"))
306
307     def test_too_long(self):
308         self.assertFalse(samba.valid_netbios_name("FOO"*10))
309
310     def test_invalid_characters(self):
311         self.assertFalse(samba.valid_netbios_name("*BLA"))
312
313
314 class BlackboxProcessError(Exception):
315     """This is raised when check_output() process returns a non-zero exit status
316
317     Exception instance should contain the exact exit code (S.returncode),
318     command line (S.cmd), process output (S.stdout) and process error stream
319     (S.stderr)
320     """
321
322     def __init__(self, returncode, cmd, stdout, stderr):
323         self.returncode = returncode
324         self.cmd = cmd
325         self.stdout = stdout
326         self.stderr = stderr
327
328     def __str__(self):
329         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
330                                                                              self.stdout, self.stderr)
331
332 class BlackboxTestCase(TestCaseInTempDir):
333     """Base test case for blackbox tests."""
334
335     def _make_cmdline(self, line):
336         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
337         parts = line.split(" ")
338         if os.path.exists(os.path.join(bindir, parts[0])):
339             parts[0] = os.path.join(bindir, parts[0])
340         line = " ".join(parts)
341         return line
342
343     def check_run(self, line):
344         self.check_exit_code(line, 0)
345
346     def check_exit_code(self, line, expected):
347         line = self._make_cmdline(line)
348         p = subprocess.Popen(line,
349                              stdout=subprocess.PIPE,
350                              stderr=subprocess.PIPE,
351                              shell=True)
352         stdoutdata, stderrdata = p.communicate()
353         retcode = p.returncode
354         if retcode != expected:
355             raise BlackboxProcessError(retcode,
356                                        line,
357                                        stdoutdata,
358                                        stderrdata)
359
360     def check_output(self, line):
361         line = self._make_cmdline(line)
362         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
363         stdoutdata, stderrdata = p.communicate()
364         retcode = p.returncode
365         if retcode:
366             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
367         return stdoutdata
368
369
370 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
371                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
372     """Create SamDB instance and connects to samdb_url database.
373
374     :param samdb_url: Url for database to connect to.
375     :param lp: Optional loadparm object
376     :param session_info: Optional session information
377     :param credentials: Optional credentials, defaults to anonymous.
378     :param flags: Optional LDB flags
379     :param ldap_only: If set, only remote LDAP connection will be created.
380     :param global_schema: Whether to use global schema.
381
382     Added value for tests is that we have a shorthand function
383     to make proper URL for ldb.connect() while using default
384     parameters for connection based on test environment
385     """
386     if not "://" in samdb_url:
387         if not ldap_only and os.path.isfile(samdb_url):
388             samdb_url = "tdb://%s" % samdb_url
389         else:
390             samdb_url = "ldap://%s" % samdb_url
391     # use 'paged_search' module when connecting remotely
392     if samdb_url.startswith("ldap://"):
393         ldb_options = ["modules:paged_searches"]
394     elif ldap_only:
395         raise AssertionError("Trying to connect to %s while remote "
396                              "connection is required" % samdb_url)
397
398     # set defaults for test environment
399     if lp is None:
400         lp = env_loadparm()
401     if session_info is None:
402         session_info = samba.auth.system_session(lp)
403     if credentials is None:
404         credentials = cmdline_credentials
405
406     return SamDB(url=samdb_url,
407                  lp=lp,
408                  session_info=session_info,
409                  credentials=credentials,
410                  flags=flags,
411                  options=ldb_options,
412                  global_schema=global_schema)
413
414
415 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
416                      flags=0, ldb_options=None, ldap_only=False):
417     """Connects to samdb_url database
418
419     :param samdb_url: Url for database to connect to.
420     :param lp: Optional loadparm object
421     :param session_info: Optional session information
422     :param credentials: Optional credentials, defaults to anonymous.
423     :param flags: Optional LDB flags
424     :param ldap_only: If set, only remote LDAP connection will be created.
425     :return: (sam_db_connection, rootDse_record) tuple
426     """
427     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
428                            flags, ldb_options, ldap_only)
429     # fetch RootDse
430     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
431                         attrs=["*"])
432     return (sam_db, res[0])
433
434
435 def connect_samdb_env(env_url, env_username, env_password, lp=None):
436     """Connect to SamDB by getting URL and Credentials from environment
437
438     :param env_url: Environment variable name to get lsb url from
439     :param env_username: Username environment variable
440     :param env_password: Password environment variable
441     :return: sam_db_connection
442     """
443     samdb_url = env_get_var_value(env_url)
444     creds = credentials.Credentials()
445     if lp is None:
446         # guess Credentials parameters here. Otherwise workstation
447         # and domain fields are NULL and gencache code segfalts
448         lp = param.LoadParm()
449         creds.guess(lp)
450     creds.set_username(env_get_var_value(env_username))
451     creds.set_password(env_get_var_value(env_password))
452     return connect_samdb(samdb_url, credentials=creds, lp=lp)
453
454
455 def delete_force(samdb, dn, **kwargs):
456     try:
457         samdb.delete(dn, **kwargs)
458     except ldb.LdbError as error:
459         (num, errstr) = error.args
460         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr