PEP8: fix E303: too many blank lines (2)
[sfrench/samba-autobuild/.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import tempfile
23 import ldb
24 import samba
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
29 import socket
30 import struct
31 import subprocess
32 import sys
33 import tempfile
34 import unittest
35 import re
36 import samba.auth
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
41 try:
42     from samba.samdb import SamDB
43 except ImportError:
44     # We are built without samdb support,
45     # imitate it so that connect_samdb() can recover
46     def SamDB(*args, **kwargs):
47         return None
48
49 import samba.ndr
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
52
53 try:
54     from unittest import SkipTest
55 except ImportError:
56     class SkipTest(Exception):
57         """Test skipped."""
58
59 HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
60
61
62 class TestCase(unittest.TestCase):
63     """A Samba test case."""
64
65     def setUp(self):
66         super(TestCase, self).setUp()
67         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
68         if test_debug_level is not None:
69             test_debug_level = int(test_debug_level)
70             self._old_debug_level = samba.get_debug_level()
71             samba.set_debug_level(test_debug_level)
72             self.addCleanup(samba.set_debug_level, test_debug_level)
73
74     def get_loadparm(self):
75         return env_loadparm()
76
77     def get_credentials(self):
78         return cmdline_credentials
79
80     def get_creds_ccache_name(self):
81         creds = self.get_credentials()
82         ccache = creds.get_named_ccache(self.get_loadparm())
83         ccache_name = ccache.get_name()
84
85         return ccache_name
86
87     def hexdump(self, src):
88         N = 0
89         result = ''
90         is_string = isinstance(src, string_types)
91         while src:
92             ll = src[:8]
93             lr = src[8:16]
94             src = src[16:]
95             if is_string:
96                 hl = ' '.join(["%02X" % ord(x) for x in ll])
97                 hr = ' '.join(["%02X" % ord(x) for x in lr])
98                 ll = ll.translate(HEXDUMP_FILTER)
99                 lr = lr.translate(HEXDUMP_FILTER)
100             else:
101                 hl = ' '.join(["%02X" % x for x in ll])
102                 hr = ' '.join(["%02X" % x for x in lr])
103                 ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
104                 lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
105             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
106             N += 16
107         return result
108
109     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
110
111         if template is None:
112             assert template is not None
113
114         if username is not None:
115             assert userpass is not None
116
117         if username is None:
118             assert userpass is None
119
120             username = template.get_username()
121             userpass = template.get_password()
122
123         if kerberos_state is None:
124             kerberos_state = template.get_kerberos_state()
125
126         # get a copy of the global creds or a the passed in creds
127         c = Credentials()
128         c.set_username(username)
129         c.set_password(userpass)
130         c.set_domain(template.get_domain())
131         c.set_realm(template.get_realm())
132         c.set_workstation(template.get_workstation())
133         c.set_gensec_features(c.get_gensec_features()
134                               | gensec.FEATURE_SEAL)
135         c.set_kerberos_state(kerberos_state)
136         return c
137
138     # These functions didn't exist before Python2.7:
139     if sys.version_info < (2, 7):
140         import warnings
141
142         def skipTest(self, reason):
143             raise SkipTest(reason)
144
145         def assertIn(self, member, container, msg=None):
146             self.assertTrue(member in container, msg)
147
148         def assertIs(self, a, b, msg=None):
149             self.assertTrue(a is b, msg)
150
151         def assertIsNot(self, a, b, msg=None):
152             self.assertTrue(a is not b, msg)
153
154         def assertIsNotNone(self, a, msg=None):
155             self.assertTrue(a is not None)
156
157         def assertIsInstance(self, a, b, msg=None):
158             self.assertTrue(isinstance(a, b), msg)
159
160         def assertIsNone(self, a, msg=None):
161             self.assertTrue(a is None, msg)
162
163         def assertGreater(self, a, b, msg=None):
164             self.assertTrue(a > b, msg)
165
166         def assertGreaterEqual(self, a, b, msg=None):
167             self.assertTrue(a >= b, msg)
168
169         def assertLess(self, a, b, msg=None):
170             self.assertTrue(a < b, msg)
171
172         def assertLessEqual(self, a, b, msg=None):
173             self.assertTrue(a <= b, msg)
174
175         def addCleanup(self, fn, *args, **kwargs):
176             self._cleanups = getattr(self, "_cleanups", []) + [
177                 (fn, args, kwargs)]
178
179         def assertRegexpMatches(self, text, regex, msg=None):
180             # PY3 note: Python 3 will never see this, but we use
181             # text_type for the benefit of linters.
182             if isinstance(regex, (str, text_type)):
183                 regex = re.compile(regex)
184             if not regex.search(text):
185                 self.fail(msg)
186
187         def _addSkip(self, result, reason):
188             addSkip = getattr(result, 'addSkip', None)
189             if addSkip is not None:
190                 addSkip(self, reason)
191             else:
192                 warnings.warn("TestResult has no addSkip method, skips not reported",
193                               RuntimeWarning, 2)
194                 result.addSuccess(self)
195
196         def run(self, result=None):
197             if result is None: result = self.defaultTestResult()
198             result.startTest(self)
199             testMethod = getattr(self, self._testMethodName)
200             try:
201                 try:
202                     self.setUp()
203                 except SkipTest as e:
204                     self._addSkip(result, str(e))
205                     return
206                 except KeyboardInterrupt:
207                     raise
208                 except:
209                     result.addError(self, self._exc_info())
210                     return
211
212                 ok = False
213                 try:
214                     testMethod()
215                     ok = True
216                 except SkipTest as e:
217                     self._addSkip(result, str(e))
218                     return
219                 except self.failureException:
220                     result.addFailure(self, self._exc_info())
221                 except KeyboardInterrupt:
222                     raise
223                 except:
224                     result.addError(self, self._exc_info())
225
226                 try:
227                     self.tearDown()
228                 except SkipTest as e:
229                     self._addSkip(result, str(e))
230                 except KeyboardInterrupt:
231                     raise
232                 except:
233                     result.addError(self, self._exc_info())
234                     ok = False
235
236                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
237                     fn(*args, **kwargs)
238                 if ok: result.addSuccess(self)
239             finally:
240                 result.stopTest(self)
241
242     def assertStringsEqual(self, a, b, msg=None, strip=False):
243         """Assert equality between two strings and highlight any differences.
244         If strip is true, leading and trailing whitespace is ignored."""
245         if strip:
246             a = a.strip()
247             b = b.strip()
248
249         if a != b:
250             sys.stderr.write("The strings differ %s(lengths %d vs %d); "
251                              "a diff follows\n"
252                              % ('when stripped ' if strip else '',
253                                 len(a), len(b),
254                                 ))
255
256             from difflib import unified_diff
257             diff = unified_diff(a.splitlines(True),
258                                 b.splitlines(True),
259                                 'a', 'b')
260             for line in diff:
261                 sys.stderr.write(line)
262
263             self.fail(msg)
264
265
266 class LdbTestCase(TestCase):
267     """Trivial test case for running tests against a LDB."""
268
269     def setUp(self):
270         super(LdbTestCase, self).setUp()
271         self.tempfile = tempfile.NamedTemporaryFile(delete=False)
272         self.filename = self.tempfile.name
273         self.ldb = samba.Ldb(self.filename)
274
275     def set_modules(self, modules=[]):
276         """Change the modules for this Ldb."""
277         m = ldb.Message()
278         m.dn = ldb.Dn(self.ldb, "@MODULES")
279         m["@LIST"] = ",".join(modules)
280         self.ldb.add(m)
281         self.ldb = samba.Ldb(self.filename)
282
283
284 class TestCaseInTempDir(TestCase):
285
286     def setUp(self):
287         super(TestCaseInTempDir, self).setUp()
288         self.tempdir = tempfile.mkdtemp()
289         self.addCleanup(self._remove_tempdir)
290
291     def _remove_tempdir(self):
292         self.assertEquals([], os.listdir(self.tempdir))
293         os.rmdir(self.tempdir)
294         self.tempdir = None
295
296
297 def env_loadparm():
298     lp = param.LoadParm()
299     try:
300         lp.load(os.environ["SMB_CONF_PATH"])
301     except KeyError:
302         raise KeyError("SMB_CONF_PATH not set")
303     return lp
304
305
306 def env_get_var_value(var_name, allow_missing=False):
307     """Returns value for variable in os.environ
308
309     Function throws AssertionError if variable is defined.
310     Unit-test based python tests require certain input params
311     to be set in environment, otherwise they can't be run
312     """
313     if allow_missing:
314         if var_name not in os.environ.keys():
315             return None
316     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
317     return os.environ[var_name]
318
319
320 cmdline_credentials = None
321
322
323 class RpcInterfaceTestCase(TestCase):
324     """DCE/RPC Test case."""
325
326
327 class ValidNetbiosNameTests(TestCase):
328
329     def test_valid(self):
330         self.assertTrue(samba.valid_netbios_name("FOO"))
331
332     def test_too_long(self):
333         self.assertFalse(samba.valid_netbios_name("FOO" * 10))
334
335     def test_invalid_characters(self):
336         self.assertFalse(samba.valid_netbios_name("*BLA"))
337
338
339 class BlackboxProcessError(Exception):
340     """This is raised when check_output() process returns a non-zero exit status
341
342     Exception instance should contain the exact exit code (S.returncode),
343     command line (S.cmd), process output (S.stdout) and process error stream
344     (S.stderr)
345     """
346
347     def __init__(self, returncode, cmd, stdout, stderr, msg=None):
348         self.returncode = returncode
349         self.cmd = cmd
350         self.stdout = stdout
351         self.stderr = stderr
352         self.msg = msg
353
354     def __str__(self):
355         s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
356              (self.cmd, self.returncode, self.stdout, self.stderr))
357         if self.msg is not None:
358             s = "%s; message: %s" % (s, self.msg)
359
360         return s
361
362
363 class BlackboxTestCase(TestCaseInTempDir):
364     """Base test case for blackbox tests."""
365
366     def _make_cmdline(self, line):
367         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
368         parts = line.split(" ")
369         if os.path.exists(os.path.join(bindir, parts[0])):
370             parts[0] = os.path.join(bindir, parts[0])
371         line = " ".join(parts)
372         return line
373
374     def check_run(self, line, msg=None):
375         self.check_exit_code(line, 0, msg=msg)
376
377     def check_exit_code(self, line, expected, msg=None):
378         line = self._make_cmdline(line)
379         p = subprocess.Popen(line,
380                              stdout=subprocess.PIPE,
381                              stderr=subprocess.PIPE,
382                              shell=True)
383         stdoutdata, stderrdata = p.communicate()
384         retcode = p.returncode
385         if retcode != expected:
386             raise BlackboxProcessError(retcode,
387                                        line,
388                                        stdoutdata,
389                                        stderrdata,
390                                        msg)
391
392     def check_output(self, line):
393         line = self._make_cmdline(line)
394         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
395         stdoutdata, stderrdata = p.communicate()
396         retcode = p.returncode
397         if retcode:
398             raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
399         return stdoutdata
400
401
402 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
403                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
404     """Create SamDB instance and connects to samdb_url database.
405
406     :param samdb_url: Url for database to connect to.
407     :param lp: Optional loadparm object
408     :param session_info: Optional session information
409     :param credentials: Optional credentials, defaults to anonymous.
410     :param flags: Optional LDB flags
411     :param ldap_only: If set, only remote LDAP connection will be created.
412     :param global_schema: Whether to use global schema.
413
414     Added value for tests is that we have a shorthand function
415     to make proper URL for ldb.connect() while using default
416     parameters for connection based on test environment
417     """
418     if not "://" in samdb_url:
419         if not ldap_only and os.path.isfile(samdb_url):
420             samdb_url = "tdb://%s" % samdb_url
421         else:
422             samdb_url = "ldap://%s" % samdb_url
423     # use 'paged_search' module when connecting remotely
424     if samdb_url.startswith("ldap://"):
425         ldb_options = ["modules:paged_searches"]
426     elif ldap_only:
427         raise AssertionError("Trying to connect to %s while remote "
428                              "connection is required" % samdb_url)
429
430     # set defaults for test environment
431     if lp is None:
432         lp = env_loadparm()
433     if session_info is None:
434         session_info = samba.auth.system_session(lp)
435     if credentials is None:
436         credentials = cmdline_credentials
437
438     return SamDB(url=samdb_url,
439                  lp=lp,
440                  session_info=session_info,
441                  credentials=credentials,
442                  flags=flags,
443                  options=ldb_options,
444                  global_schema=global_schema)
445
446
447 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
448                      flags=0, ldb_options=None, ldap_only=False):
449     """Connects to samdb_url database
450
451     :param samdb_url: Url for database to connect to.
452     :param lp: Optional loadparm object
453     :param session_info: Optional session information
454     :param credentials: Optional credentials, defaults to anonymous.
455     :param flags: Optional LDB flags
456     :param ldap_only: If set, only remote LDAP connection will be created.
457     :return: (sam_db_connection, rootDse_record) tuple
458     """
459     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
460                            flags, ldb_options, ldap_only)
461     # fetch RootDse
462     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
463                         attrs=["*"])
464     return (sam_db, res[0])
465
466
467 def connect_samdb_env(env_url, env_username, env_password, lp=None):
468     """Connect to SamDB by getting URL and Credentials from environment
469
470     :param env_url: Environment variable name to get lsb url from
471     :param env_username: Username environment variable
472     :param env_password: Password environment variable
473     :return: sam_db_connection
474     """
475     samdb_url = env_get_var_value(env_url)
476     creds = credentials.Credentials()
477     if lp is None:
478         # guess Credentials parameters here. Otherwise workstation
479         # and domain fields are NULL and gencache code segfalts
480         lp = param.LoadParm()
481         creds.guess(lp)
482     creds.set_username(env_get_var_value(env_username))
483     creds.set_password(env_get_var_value(env_password))
484     return connect_samdb(samdb_url, credentials=creds, lp=lp)
485
486
487 def delete_force(samdb, dn, **kwargs):
488     try:
489         samdb.delete(dn, **kwargs)
490     except ldb.LdbError as error:
491         (num, errstr) = error.args
492         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
493
494
495 def create_test_ou(samdb, name):
496     """Creates a unique OU for the test"""
497
498     # Add some randomness to the test OU. Replication between the testenvs is
499     # constantly happening in the background. Deletion of the last test's
500     # objects can be slow to replicate out. So the OU created by a previous
501     # testenv may still exist at the point that tests start on another testenv.
502     rand = randint(1, 10000000)
503     dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
504     samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
505     return dn