PEP8: fix E302: expected 2 blank lines, found 1
[garming/samba-autobuild/.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import tempfile
23 import ldb
24 import samba
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
29 import socket
30 import struct
31 import subprocess
32 import sys
33 import tempfile
34 import unittest
35 import re
36 import samba.auth
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
41 try:
42     from samba.samdb import SamDB
43 except ImportError:
44     # We are built without samdb support,
45     # imitate it so that connect_samdb() can recover
46     def SamDB(*args, **kwargs):
47         return None
48
49 import samba.ndr
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
52
53 try:
54     from unittest import SkipTest
55 except ImportError:
56     class SkipTest(Exception):
57         """Test skipped."""
58
59 HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
60
61
62 class TestCase(unittest.TestCase):
63     """A Samba test case."""
64
65     def setUp(self):
66         super(TestCase, self).setUp()
67         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
68         if test_debug_level is not None:
69             test_debug_level = int(test_debug_level)
70             self._old_debug_level = samba.get_debug_level()
71             samba.set_debug_level(test_debug_level)
72             self.addCleanup(samba.set_debug_level, test_debug_level)
73
74     def get_loadparm(self):
75         return env_loadparm()
76
77     def get_credentials(self):
78         return cmdline_credentials
79
80     def get_creds_ccache_name(self):
81         creds = self.get_credentials()
82         ccache = creds.get_named_ccache(self.get_loadparm())
83         ccache_name = ccache.get_name()
84
85         return ccache_name
86
87     def hexdump(self, src):
88         N = 0
89         result = ''
90         is_string = isinstance(src, string_types)
91         while src:
92             ll = src[:8]
93             lr = src[8:16]
94             src = src[16:]
95             if is_string:
96                 hl = ' '.join(["%02X" % ord(x) for x in ll])
97                 hr = ' '.join(["%02X" % ord(x) for x in lr])
98                 ll = ll.translate(HEXDUMP_FILTER)
99                 lr = lr.translate(HEXDUMP_FILTER)
100             else:
101                 hl = ' '.join(["%02X" % x for x in ll])
102                 hr = ' '.join(["%02X" % x for x in lr])
103                 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
104                 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
105             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
106             N += 16
107         return result
108
109     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
110
111         if template is None:
112             assert template is not None
113
114         if username is not None:
115             assert userpass is not None
116
117         if username is None:
118             assert userpass is None
119
120             username = template.get_username()
121             userpass = template.get_password()
122
123         if kerberos_state is None:
124             kerberos_state = template.get_kerberos_state()
125
126         # get a copy of the global creds or a the passed in creds
127         c = Credentials()
128         c.set_username(username)
129         c.set_password(userpass)
130         c.set_domain(template.get_domain())
131         c.set_realm(template.get_realm())
132         c.set_workstation(template.get_workstation())
133         c.set_gensec_features(c.get_gensec_features()
134                               | gensec.FEATURE_SEAL)
135         c.set_kerberos_state(kerberos_state)
136         return c
137
138
139
140     # These functions didn't exist before Python2.7:
141     if sys.version_info < (2, 7):
142         import warnings
143
144         def skipTest(self, reason):
145             raise SkipTest(reason)
146
147         def assertIn(self, member, container, msg=None):
148             self.assertTrue(member in container, msg)
149
150         def assertIs(self, a, b, msg=None):
151             self.assertTrue(a is b, msg)
152
153         def assertIsNot(self, a, b, msg=None):
154             self.assertTrue(a is not b, msg)
155
156         def assertIsNotNone(self, a, msg=None):
157             self.assertTrue(a is not None)
158
159         def assertIsInstance(self, a, b, msg=None):
160             self.assertTrue(isinstance(a, b), msg)
161
162         def assertIsNone(self, a, msg=None):
163             self.assertTrue(a is None, msg)
164
165         def assertGreater(self, a, b, msg=None):
166             self.assertTrue(a > b, msg)
167
168         def assertGreaterEqual(self, a, b, msg=None):
169             self.assertTrue(a >= b, msg)
170
171         def assertLess(self, a, b, msg=None):
172             self.assertTrue(a < b, msg)
173
174         def assertLessEqual(self, a, b, msg=None):
175             self.assertTrue(a <= b, msg)
176
177         def addCleanup(self, fn, *args, **kwargs):
178             self._cleanups = getattr(self, "_cleanups", []) + [
179                 (fn, args, kwargs)]
180
181         def assertRegexpMatches(self, text, regex, msg=None):
182             # PY3 note: Python 3 will never see this, but we use
183             # text_type for the benefit of linters.
184             if isinstance(regex, (str, text_type)):
185                 regex = re.compile(regex)
186             if not regex.search(text):
187                 self.fail(msg)
188
189         def _addSkip(self, result, reason):
190             addSkip = getattr(result, 'addSkip', None)
191             if addSkip is not None:
192                 addSkip(self, reason)
193             else:
194                 warnings.warn("TestResult has no addSkip method, skips not reported",
195                               RuntimeWarning, 2)
196                 result.addSuccess(self)
197
198         def run(self, result=None):
199             if result is None: result = self.defaultTestResult()
200             result.startTest(self)
201             testMethod = getattr(self, self._testMethodName)
202             try:
203                 try:
204                     self.setUp()
205                 except SkipTest as e:
206                     self._addSkip(result, str(e))
207                     return
208                 except KeyboardInterrupt:
209                     raise
210                 except:
211                     result.addError(self, self._exc_info())
212                     return
213
214                 ok = False
215                 try:
216                     testMethod()
217                     ok = True
218                 except SkipTest as e:
219                     self._addSkip(result, str(e))
220                     return
221                 except self.failureException:
222                     result.addFailure(self, self._exc_info())
223                 except KeyboardInterrupt:
224                     raise
225                 except:
226                     result.addError(self, self._exc_info())
227
228                 try:
229                     self.tearDown()
230                 except SkipTest as e:
231                     self._addSkip(result, str(e))
232                 except KeyboardInterrupt:
233                     raise
234                 except:
235                     result.addError(self, self._exc_info())
236                     ok = False
237
238                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
239                     fn(*args, **kwargs)
240                 if ok: result.addSuccess(self)
241             finally:
242                 result.stopTest(self)
243
244     def assertStringsEqual(self, a, b, msg=None, strip=False):
245         """Assert equality between two strings and highlight any differences.
246         If strip is true, leading and trailing whitespace is ignored."""
247         if strip:
248             a = a.strip()
249             b = b.strip()
250
251         if a != b:
252             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
253                              "a diff follows\n"
254                              % ('when stripped ' if strip else '',
255                                 len(a), len(b),
256                                 ))
257
258             from difflib import unified_diff
259             diff = unified_diff(a.splitlines(True),
260                                 b.splitlines(True),
261                                 'a', 'b')
262             for line in diff:
263                 sys.stderr.write(line)
264
265             self.fail(msg)
266
267
268 class LdbTestCase(TestCase):
269     """Trivial test case for running tests against a LDB."""
270
271     def setUp(self):
272         super(LdbTestCase, self).setUp()
273         self.tempfile = tempfile.NamedTemporaryFile(delete=False)
274         self.filename = self.tempfile.name
275         self.ldb = samba.Ldb(self.filename)
276
277     def set_modules(self, modules=[]):
278         """Change the modules for this Ldb."""
279         m = ldb.Message()
280         m.dn = ldb.Dn(self.ldb, "@MODULES")
281         m["@LIST"] = ",".join(modules)
282         self.ldb.add(m)
283         self.ldb = samba.Ldb(self.filename)
284
285
286 class TestCaseInTempDir(TestCase):
287
288     def setUp(self):
289         super(TestCaseInTempDir, self).setUp()
290         self.tempdir = tempfile.mkdtemp()
291         self.addCleanup(self._remove_tempdir)
292
293     def _remove_tempdir(self):
294         self.assertEquals([], os.listdir(self.tempdir))
295         os.rmdir(self.tempdir)
296         self.tempdir = None
297
298
299 def env_loadparm():
300     lp = param.LoadParm()
301     try:
302         lp.load(os.environ["SMB_CONF_PATH"])
303     except KeyError:
304         raise KeyError("SMB_CONF_PATH not set")
305     return lp
306
307
308 def env_get_var_value(var_name, allow_missing=False):
309     """Returns value for variable in os.environ
310
311     Function throws AssertionError if variable is defined.
312     Unit-test based python tests require certain input params
313     to be set in environment, otherwise they can't be run
314     """
315     if allow_missing:
316         if var_name not in os.environ.keys():
317             return None
318     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
319     return os.environ[var_name]
320
321
322 cmdline_credentials = None
323
324
325 class RpcInterfaceTestCase(TestCase):
326     """DCE/RPC Test case."""
327
328
329 class ValidNetbiosNameTests(TestCase):
330
331     def test_valid(self):
332         self.assertTrue(samba.valid_netbios_name("FOO"))
333
334     def test_too_long(self):
335         self.assertFalse(samba.valid_netbios_name("FOO" * 10))
336
337     def test_invalid_characters(self):
338         self.assertFalse(samba.valid_netbios_name("*BLA"))
339
340
341 class BlackboxProcessError(Exception):
342     """This is raised when check_output() process returns a non-zero exit status
343
344     Exception instance should contain the exact exit code (S.returncode),
345     command line (S.cmd), process output (S.stdout) and process error stream
346     (S.stderr)
347     """
348
349     def __init__(self, returncode, cmd, stdout, stderr, msg=None):
350         self.returncode = returncode
351         self.cmd = cmd
352         self.stdout = stdout
353         self.stderr = stderr
354         self.msg = msg
355
356     def __str__(self):
357         s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
358              (self.cmd, self.returncode, self.stdout, self.stderr))
359         if self.msg is not None:
360             s = "%s; message: %s" % (s, self.msg)
361
362         return s
363
364
365 class BlackboxTestCase(TestCaseInTempDir):
366     """Base test case for blackbox tests."""
367
368     def _make_cmdline(self, line):
369         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
370         parts = line.split(" ")
371         if os.path.exists(os.path.join(bindir, parts[0])):
372             parts[0] = os.path.join(bindir, parts[0])
373         line = " ".join(parts)
374         return line
375
376     def check_run(self, line, msg=None):
377         self.check_exit_code(line, 0, msg=msg)
378
379     def check_exit_code(self, line, expected, msg=None):
380         line = self._make_cmdline(line)
381         p = subprocess.Popen(line,
382                              stdout=subprocess.PIPE,
383                              stderr=subprocess.PIPE,
384                              shell=True)
385         stdoutdata, stderrdata = p.communicate()
386         retcode = p.returncode
387         if retcode != expected:
388             raise BlackboxProcessError(retcode,
389                                        line,
390                                        stdoutdata,
391                                        stderrdata,
392                                        msg)
393
394     def check_output(self, line):
395         line = self._make_cmdline(line)
396         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
397         stdoutdata, stderrdata = p.communicate()
398         retcode = p.returncode
399         if retcode:
400             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
401         return stdoutdata
402
403
404 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
405                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
406     """Create SamDB instance and connects to samdb_url database.
407
408     :param samdb_url: Url for database to connect to.
409     :param lp: Optional loadparm object
410     :param session_info: Optional session information
411     :param credentials: Optional credentials, defaults to anonymous.
412     :param flags: Optional LDB flags
413     :param ldap_only: If set, only remote LDAP connection will be created.
414     :param global_schema: Whether to use global schema.
415
416     Added value for tests is that we have a shorthand function
417     to make proper URL for ldb.connect() while using default
418     parameters for connection based on test environment
419     """
420     if not "://" in samdb_url:
421         if not ldap_only and os.path.isfile(samdb_url):
422             samdb_url = "tdb://%s" % samdb_url
423         else:
424             samdb_url = "ldap://%s" % samdb_url
425     # use 'paged_search' module when connecting remotely
426     if samdb_url.startswith("ldap://"):
427         ldb_options = ["modules:paged_searches"]
428     elif ldap_only:
429         raise AssertionError("Trying to connect to %s while remote "
430                              "connection is required" % samdb_url)
431
432     # set defaults for test environment
433     if lp is None:
434         lp = env_loadparm()
435     if session_info is None:
436         session_info = samba.auth.system_session(lp)
437     if credentials is None:
438         credentials = cmdline_credentials
439
440     return SamDB(url=samdb_url,
441                  lp=lp,
442                  session_info=session_info,
443                  credentials=credentials,
444                  flags=flags,
445                  options=ldb_options,
446                  global_schema=global_schema)
447
448
449 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
450                      flags=0, ldb_options=None, ldap_only=False):
451     """Connects to samdb_url database
452
453     :param samdb_url: Url for database to connect to.
454     :param lp: Optional loadparm object
455     :param session_info: Optional session information
456     :param credentials: Optional credentials, defaults to anonymous.
457     :param flags: Optional LDB flags
458     :param ldap_only: If set, only remote LDAP connection will be created.
459     :return: (sam_db_connection, rootDse_record) tuple
460     """
461     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
462                            flags, ldb_options, ldap_only)
463     # fetch RootDse
464     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
465                         attrs=["*"])
466     return (sam_db, res[0])
467
468
469 def connect_samdb_env(env_url, env_username, env_password, lp=None):
470     """Connect to SamDB by getting URL and Credentials from environment
471
472     :param env_url: Environment variable name to get lsb url from
473     :param env_username: Username environment variable
474     :param env_password: Password environment variable
475     :return: sam_db_connection
476     """
477     samdb_url = env_get_var_value(env_url)
478     creds = credentials.Credentials()
479     if lp is None:
480         # guess Credentials parameters here. Otherwise workstation
481         # and domain fields are NULL and gencache code segfalts
482         lp = param.LoadParm()
483         creds.guess(lp)
484     creds.set_username(env_get_var_value(env_username))
485     creds.set_password(env_get_var_value(env_password))
486     return connect_samdb(samdb_url, credentials=creds, lp=lp)
487
488
489 def delete_force(samdb, dn, **kwargs):
490     try:
491         samdb.delete(dn, **kwargs)
492     except ldb.LdbError as error:
493         (num, errstr) = error.args
494         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
495
496
497 def create_test_ou(samdb, name):
498     """Creates a unique OU for the test"""
499
500     # Add some randomness to the test OU. Replication between the testenvs is
501     # constantly happening in the background. Deletion of the last test's
502     # objects can be slow to replicate out. So the OU created by a previous
503     # testenv may still exist at the point that tests start on another testenv.
504     rand = randint(1, 10000000)
505     dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
506     samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
507     return dn