bd17a870d66780dfd452e93d22f4338b61ecd92a
[nivanova/samba-autobuild/.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import tempfile
23 import ldb
24 import samba
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
29 import socket
30 import struct
31 import subprocess
32 import sys
33 import tempfile
34 import unittest
35 import re
36 import samba.auth
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
41 try:
42     from samba.samdb import SamDB
43 except ImportError:
44     # We are built without samdb support,
45     # imitate it so that connect_samdb() can recover
46     def SamDB(*args, **kwargs):
47         return None
48
49 import samba.ndr
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
52
53 try:
54     from unittest import SkipTest
55 except ImportError:
56     class SkipTest(Exception):
57         """Test skipped."""
58
59 HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
60
61 class TestCase(unittest.TestCase):
62     """A Samba test case."""
63
64     def setUp(self):
65         super(TestCase, self).setUp()
66         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
67         if test_debug_level is not None:
68             test_debug_level = int(test_debug_level)
69             self._old_debug_level = samba.get_debug_level()
70             samba.set_debug_level(test_debug_level)
71             self.addCleanup(samba.set_debug_level, test_debug_level)
72
73     def get_loadparm(self):
74         return env_loadparm()
75
76     def get_credentials(self):
77         return cmdline_credentials
78
79     def get_creds_ccache_name(self):
80         creds = self.get_credentials()
81         ccache = creds.get_named_ccache(self.get_loadparm())
82         ccache_name = ccache.get_name()
83
84         return ccache_name
85
86     def hexdump(self, src):
87         N = 0
88         result = ''
89         is_string = isinstance(src, string_types)
90         while src:
91             ll = src[:8]
92             lr = src[8:16]
93             src = src[16:]
94             if is_string:
95                 hl = ' '.join(["%02X" % ord(x) for x in ll])
96                 hr = ' '.join(["%02X" % ord(x) for x in lr])
97                 ll = ll.translate(HEXDUMP_FILTER)
98                 lr = lr.translate(HEXDUMP_FILTER)
99             else:
100                 hl = ' '.join(["%02X" % x for x in ll])
101                 hr = ' '.join(["%02X" % x for x in lr])
102                 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
103                 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
104             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
105             N += 16
106         return result
107
108     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
109
110         if template is None:
111             assert template is not None
112
113         if username is not None:
114             assert userpass is not None
115
116         if username is None:
117             assert userpass is None
118
119             username = template.get_username()
120             userpass = template.get_password()
121
122         if kerberos_state is None:
123             kerberos_state = template.get_kerberos_state()
124
125         # get a copy of the global creds or a the passed in creds
126         c = Credentials()
127         c.set_username(username)
128         c.set_password(userpass)
129         c.set_domain(template.get_domain())
130         c.set_realm(template.get_realm())
131         c.set_workstation(template.get_workstation())
132         c.set_gensec_features(c.get_gensec_features()
133                               | gensec.FEATURE_SEAL)
134         c.set_kerberos_state(kerberos_state)
135         return c
136
137
138
139     # These functions didn't exist before Python2.7:
140     if sys.version_info < (2, 7):
141         import warnings
142
143         def skipTest(self, reason):
144             raise SkipTest(reason)
145
146         def assertIn(self, member, container, msg=None):
147             self.assertTrue(member in container, msg)
148
149         def assertIs(self, a, b, msg=None):
150             self.assertTrue(a is b, msg)
151
152         def assertIsNot(self, a, b, msg=None):
153             self.assertTrue(a is not b, msg)
154
155         def assertIsNotNone(self, a, msg=None):
156             self.assertTrue(a is not None)
157
158         def assertIsInstance(self, a, b, msg=None):
159             self.assertTrue(isinstance(a, b), msg)
160
161         def assertIsNone(self, a, msg=None):
162             self.assertTrue(a is None, msg)
163
164         def assertGreater(self, a, b, msg=None):
165             self.assertTrue(a > b, msg)
166
167         def assertGreaterEqual(self, a, b, msg=None):
168             self.assertTrue(a >= b, msg)
169
170         def assertLess(self, a, b, msg=None):
171             self.assertTrue(a < b, msg)
172
173         def assertLessEqual(self, a, b, msg=None):
174             self.assertTrue(a <= b, msg)
175
176         def addCleanup(self, fn, *args, **kwargs):
177             self._cleanups = getattr(self, "_cleanups", []) + [
178                 (fn, args, kwargs)]
179
180         def assertRegexpMatches(self, text, regex, msg=None):
181             # PY3 note: Python 3 will never see this, but we use
182             # text_type for the benefit of linters.
183             if isinstance(regex, (str, text_type)):
184                 regex = re.compile(regex)
185             if not regex.search(text):
186                 self.fail(msg)
187
188         def _addSkip(self, result, reason):
189             addSkip = getattr(result, 'addSkip', None)
190             if addSkip is not None:
191                 addSkip(self, reason)
192             else:
193                 warnings.warn("TestResult has no addSkip method, skips not reported",
194                               RuntimeWarning, 2)
195                 result.addSuccess(self)
196
197         def run(self, result=None):
198             if result is None: result = self.defaultTestResult()
199             result.startTest(self)
200             testMethod = getattr(self, self._testMethodName)
201             try:
202                 try:
203                     self.setUp()
204                 except SkipTest as e:
205                     self._addSkip(result, str(e))
206                     return
207                 except KeyboardInterrupt:
208                     raise
209                 except:
210                     result.addError(self, self._exc_info())
211                     return
212
213                 ok = False
214                 try:
215                     testMethod()
216                     ok = True
217                 except SkipTest as e:
218                     self._addSkip(result, str(e))
219                     return
220                 except self.failureException:
221                     result.addFailure(self, self._exc_info())
222                 except KeyboardInterrupt:
223                     raise
224                 except:
225                     result.addError(self, self._exc_info())
226
227                 try:
228                     self.tearDown()
229                 except SkipTest as e:
230                     self._addSkip(result, str(e))
231                 except KeyboardInterrupt:
232                     raise
233                 except:
234                     result.addError(self, self._exc_info())
235                     ok = False
236
237                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
238                     fn(*args, **kwargs)
239                 if ok: result.addSuccess(self)
240             finally:
241                 result.stopTest(self)
242
243     def assertStringsEqual(self, a, b, msg=None, strip=False):
244         """Assert equality between two strings and highlight any differences.
245         If strip is true, leading and trailing whitespace is ignored."""
246         if strip:
247             a = a.strip()
248             b = b.strip()
249
250         if a != b:
251             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
252                              "a diff follows\n"
253                              % ('when stripped ' if strip else '',
254                                 len(a), len(b),
255                                 ))
256
257             from difflib import unified_diff
258             diff = unified_diff(a.splitlines(True),
259                                 b.splitlines(True),
260                                 'a', 'b')
261             for line in diff:
262                 sys.stderr.write(line)
263
264             self.fail(msg)
265
266
267 class LdbTestCase(TestCase):
268     """Trivial test case for running tests against a LDB."""
269
270     def setUp(self):
271         super(LdbTestCase, self).setUp()
272         self.tempfile = tempfile.NamedTemporaryFile(delete=False)
273         self.filename = self.tempfile.name
274         self.ldb = samba.Ldb(self.filename)
275
276     def set_modules(self, modules=[]):
277         """Change the modules for this Ldb."""
278         m = ldb.Message()
279         m.dn = ldb.Dn(self.ldb, "@MODULES")
280         m["@LIST"] = ",".join(modules)
281         self.ldb.add(m)
282         self.ldb = samba.Ldb(self.filename)
283
284
285 class TestCaseInTempDir(TestCase):
286
287     def setUp(self):
288         super(TestCaseInTempDir, self).setUp()
289         self.tempdir = tempfile.mkdtemp()
290         self.addCleanup(self._remove_tempdir)
291
292     def _remove_tempdir(self):
293         self.assertEquals([], os.listdir(self.tempdir))
294         os.rmdir(self.tempdir)
295         self.tempdir = None
296
297
298 def env_loadparm():
299     lp = param.LoadParm()
300     try:
301         lp.load(os.environ["SMB_CONF_PATH"])
302     except KeyError:
303         raise KeyError("SMB_CONF_PATH not set")
304     return lp
305
306
307 def env_get_var_value(var_name, allow_missing=False):
308     """Returns value for variable in os.environ
309
310     Function throws AssertionError if variable is defined.
311     Unit-test based python tests require certain input params
312     to be set in environment, otherwise they can't be run
313     """
314     if allow_missing:
315         if var_name not in os.environ.keys():
316             return None
317     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
318     return os.environ[var_name]
319
320
321 cmdline_credentials = None
322
323 class RpcInterfaceTestCase(TestCase):
324     """DCE/RPC Test case."""
325
326
327 class ValidNetbiosNameTests(TestCase):
328
329     def test_valid(self):
330         self.assertTrue(samba.valid_netbios_name("FOO"))
331
332     def test_too_long(self):
333         self.assertFalse(samba.valid_netbios_name("FOO" * 10))
334
335     def test_invalid_characters(self):
336         self.assertFalse(samba.valid_netbios_name("*BLA"))
337
338
339 class BlackboxProcessError(Exception):
340     """This is raised when check_output() process returns a non-zero exit status
341
342     Exception instance should contain the exact exit code (S.returncode),
343     command line (S.cmd), process output (S.stdout) and process error stream
344     (S.stderr)
345     """
346
347     def __init__(self, returncode, cmd, stdout, stderr, msg=None):
348         self.returncode = returncode
349         self.cmd = cmd
350         self.stdout = stdout
351         self.stderr = stderr
352         self.msg = msg
353
354     def __str__(self):
355         s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
356              (self.cmd, self.returncode, self.stdout, self.stderr))
357         if self.msg is not None:
358             s = "%s; message: %s" % (s, self.msg)
359
360         return s
361
362 class BlackboxTestCase(TestCaseInTempDir):
363     """Base test case for blackbox tests."""
364
365     def _make_cmdline(self, line):
366         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
367         parts = line.split(" ")
368         if os.path.exists(os.path.join(bindir, parts[0])):
369             parts[0] = os.path.join(bindir, parts[0])
370         line = " ".join(parts)
371         return line
372
373     def check_run(self, line, msg=None):
374         self.check_exit_code(line, 0, msg=msg)
375
376     def check_exit_code(self, line, expected, msg=None):
377         line = self._make_cmdline(line)
378         p = subprocess.Popen(line,
379                              stdout=subprocess.PIPE,
380                              stderr=subprocess.PIPE,
381                              shell=True)
382         stdoutdata, stderrdata = p.communicate()
383         retcode = p.returncode
384         if retcode != expected:
385             raise BlackboxProcessError(retcode,
386                                        line,
387                                        stdoutdata,
388                                        stderrdata,
389                                        msg)
390
391     def check_output(self, line):
392         line = self._make_cmdline(line)
393         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
394         stdoutdata, stderrdata = p.communicate()
395         retcode = p.returncode
396         if retcode:
397             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
398         return stdoutdata
399
400
401 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
402                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
403     """Create SamDB instance and connects to samdb_url database.
404
405     :param samdb_url: Url for database to connect to.
406     :param lp: Optional loadparm object
407     :param session_info: Optional session information
408     :param credentials: Optional credentials, defaults to anonymous.
409     :param flags: Optional LDB flags
410     :param ldap_only: If set, only remote LDAP connection will be created.
411     :param global_schema: Whether to use global schema.
412
413     Added value for tests is that we have a shorthand function
414     to make proper URL for ldb.connect() while using default
415     parameters for connection based on test environment
416     """
417     if not "://" in samdb_url:
418         if not ldap_only and os.path.isfile(samdb_url):
419             samdb_url = "tdb://%s" % samdb_url
420         else:
421             samdb_url = "ldap://%s" % samdb_url
422     # use 'paged_search' module when connecting remotely
423     if samdb_url.startswith("ldap://"):
424         ldb_options = ["modules:paged_searches"]
425     elif ldap_only:
426         raise AssertionError("Trying to connect to %s while remote "
427                              "connection is required" % samdb_url)
428
429     # set defaults for test environment
430     if lp is None:
431         lp = env_loadparm()
432     if session_info is None:
433         session_info = samba.auth.system_session(lp)
434     if credentials is None:
435         credentials = cmdline_credentials
436
437     return SamDB(url=samdb_url,
438                  lp=lp,
439                  session_info=session_info,
440                  credentials=credentials,
441                  flags=flags,
442                  options=ldb_options,
443                  global_schema=global_schema)
444
445
446 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
447                      flags=0, ldb_options=None, ldap_only=False):
448     """Connects to samdb_url database
449
450     :param samdb_url: Url for database to connect to.
451     :param lp: Optional loadparm object
452     :param session_info: Optional session information
453     :param credentials: Optional credentials, defaults to anonymous.
454     :param flags: Optional LDB flags
455     :param ldap_only: If set, only remote LDAP connection will be created.
456     :return: (sam_db_connection, rootDse_record) tuple
457     """
458     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
459                            flags, ldb_options, ldap_only)
460     # fetch RootDse
461     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
462                         attrs=["*"])
463     return (sam_db, res[0])
464
465
466 def connect_samdb_env(env_url, env_username, env_password, lp=None):
467     """Connect to SamDB by getting URL and Credentials from environment
468
469     :param env_url: Environment variable name to get lsb url from
470     :param env_username: Username environment variable
471     :param env_password: Password environment variable
472     :return: sam_db_connection
473     """
474     samdb_url = env_get_var_value(env_url)
475     creds = credentials.Credentials()
476     if lp is None:
477         # guess Credentials parameters here. Otherwise workstation
478         # and domain fields are NULL and gencache code segfalts
479         lp = param.LoadParm()
480         creds.guess(lp)
481     creds.set_username(env_get_var_value(env_username))
482     creds.set_password(env_get_var_value(env_password))
483     return connect_samdb(samdb_url, credentials=creds, lp=lp)
484
485
486 def delete_force(samdb, dn, **kwargs):
487     try:
488         samdb.delete(dn, **kwargs)
489     except ldb.LdbError as error:
490         (num, errstr) = error.args
491         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
492
493 def create_test_ou(samdb, name):
494     """Creates a unique OU for the test"""
495
496     # Add some randomness to the test OU. Replication between the testenvs is
497     # constantly happening in the background. Deletion of the last test's
498     # objects can be slow to replicate out. So the OU created by a previous
499     # testenv may still exist at the point that tests start on another testenv.
500     rand = randint(1, 10000000)
501     dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
502     samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
503     return dn