1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
try:
    from samba.samdb import SamDB
except ImportError:
    # Samba was built without samdb support.  Bind a callable under the
    # same name so connect_samdb() can still call SamDB(...); it simply
    # receives None instead of a database handle.
    def SamDB(*_args, **_kwargs):
        """Stand-in for samba.samdb.SamDB; accepts anything, returns None."""
        return None
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
try:
    from unittest import SkipTest
except ImportError:
    # Fallback for unittest versions that do not provide SkipTest.
    class SkipTest(Exception):
        """Raised by a test to signal that it should be reported as skipped."""
# 256-entry translation table used with translate() by TestCase.hexdump():
# printable ASCII (space .. '~') maps to itself; control characters, DEL,
# bytes >= 0x80 and the backslash (the only printable ASCII character whose
# repr() is escaped) all map to '.'.
HEXDUMP_FILTER = bytearray(
    code if (0x20 <= code < 0x7f and code != ord('\\')) else ord('.')
    for code in range(256)
)
class TestCase(unittest.TestCase):
    """A Samba test case."""

    def setUp(self):
        """Apply the optional TEST_DEBUG_LEVEL environment override.

        When TEST_DEBUG_LEVEL is set, the Samba debug level is switched to
        that (integer) value for the duration of the test and the previous
        level is restored by a cleanup afterwards.
        """
        super(TestCase, self).setUp()
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Restore the level that was active *before* this test.  (The
            # cleanup previously re-applied test_debug_level, which made it
            # a no-op and leaked the override past the end of the test.)
            self.addCleanup(samba.set_debug_level, self._old_debug_level)

    def get_loadparm(self):
        """Return the loadparm object for the test environment."""
        return env_loadparm()

    def get_credentials(self):
        """Return the module-level cmdline_credentials."""
        return cmdline_credentials

    def get_creds_ccache_name(self):
        """Return the name of the named credentials cache for the test creds.

        The cache is obtained from get_credentials() using get_loadparm().
        """
        creds = self.get_credentials()
        ccache = creds.get_named_ccache(self.get_loadparm())
        ccache_name = ccache.get_name()

        return ccache_name

    def hexdump(self, src):
        """Return a printable hex dump of *src*, 16 items per row.

        Each row shows the offset, the hex values of items 0-7 and 8-15,
        and the items as text with unprintable characters replaced by '.'
        (via HEXDUMP_FILTER).  *src* may be a string (string_types), whose
        characters are converted with ord(), or a bytes-like object whose
        items are already ints.
        """
        N = 0
        rows = []
        is_string = isinstance(src, string_types)
        while src:
            ll = src[:8]
            lr = src[8:16]
            src = src[16:]
            if is_string:
                hl = ' '.join(["%02X" % ord(x) for x in ll])
                hr = ' '.join(["%02X" % ord(x) for x in lr])
                ll = ll.translate(HEXDUMP_FILTER)
                lr = lr.translate(HEXDUMP_FILTER)
            else:
                hl = ' '.join(["%02X" % x for x in ll])
                hr = ' '.join(["%02X" % x for x in lr])
                ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
                lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
            rows.append("[%04X] %-*s %-*s %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr))
            N += 16
        return ''.join(rows)

    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
        """Return a new Credentials object derived from *template*.

        :param template: Credentials to copy domain, realm, workstation and
            (by default) username/password/kerberos state from.  Required.
        :param username: Optional username; must be given together with
            userpass, or both omitted to take them from the template.
        :param userpass: Optional password (see username).
        :param kerberos_state: Optional kerberos state; defaults to the
            template's.
        :return: A new Credentials object with FEATURE_SEAL added to its
            gensec features.
        """
        assert template is not None

        if username is not None:
            assert userpass is not None
        else:
            assert userpass is None
            username = template.get_username()
            userpass = template.get_password()

        if kerberos_state is None:
            kerberos_state = template.get_kerberos_state()

        # Build a fresh Credentials object seeded from the template; the
        # template itself is only read, never modified.
        c = Credentials()
        c.set_username(username)
        c.set_password(userpass)
        c.set_domain(template.get_domain())
        c.set_realm(template.get_realm())
        c.set_workstation(template.get_workstation())
        c.set_gensec_features(c.get_gensec_features()
                              | gensec.FEATURE_SEAL)
        c.set_kerberos_state(kerberos_state)
        return c

    # These functions didn't exist before Python2.7:
    if sys.version_info < (2, 7):
        import warnings

        def skipTest(self, reason):
            raise SkipTest(reason)

        def assertIn(self, member, container, msg=None):
            self.assertTrue(member in container, msg)

        def assertIs(self, a, b, msg=None):
            self.assertTrue(a is b, msg)

        def assertIsNot(self, a, b, msg=None):
            self.assertTrue(a is not b, msg)

        def assertIsNotNone(self, a, msg=None):
            # Forward msg like every other fallback here (it was dropped).
            self.assertTrue(a is not None, msg)

        def assertIsInstance(self, a, b, msg=None):
            self.assertTrue(isinstance(a, b), msg)

        def assertIsNone(self, a, msg=None):
            self.assertTrue(a is None, msg)

        def assertGreater(self, a, b, msg=None):
            self.assertTrue(a > b, msg)

        def assertGreaterEqual(self, a, b, msg=None):
            self.assertTrue(a >= b, msg)

        def assertLess(self, a, b, msg=None):
            self.assertTrue(a < b, msg)

        def assertLessEqual(self, a, b, msg=None):
            self.assertTrue(a <= b, msg)

        def addCleanup(self, fn, *args, **kwargs):
            # Cleanups are stored in registration order; run() below
            # executes them in reverse.
            self._cleanups = getattr(self, "_cleanups", []) + [
                (fn, args, kwargs)]

        def assertRegexpMatches(self, text, regex, msg=None):
            # PY3 note: Python 3 will never see this, but we use
            # text_type for the benefit of linters.
            if isinstance(regex, (str, text_type)):
                regex = re.compile(regex)
            if not regex.search(text):
                self.fail(msg)

        def _addSkip(self, result, reason):
            """Report a skip, or a success if *result* has no addSkip()."""
            addSkip = getattr(result, 'addSkip', None)
            if addSkip is not None:
                addSkip(self, reason)
            else:
                warnings.warn("TestResult has no addSkip method, skips not reported",
                              RuntimeWarning, 2)
                result.addSuccess(self)

        def run(self, result=None):
            """Run the test, understanding SkipTest and the cleanups
            registered through the addCleanup() fallback above."""
            if result is None:
                result = self.defaultTestResult()
            result.startTest(self)
            testMethod = getattr(self, self._testMethodName)
            try:
                try:
                    self.setUp()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except KeyboardInterrupt:
                    raise
                except:
                    # Deliberately broad: any other setUp failure is
                    # reported as a test error.
                    result.addError(self, self._exc_info())
                    return

                ok = False
                try:
                    testMethod()
                    ok = True
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except self.failureException:
                    result.addFailure(self, self._exc_info())
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())

                try:
                    self.tearDown()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    ok = False

                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
                    fn(*args, **kwargs)
                if ok:
                    result.addSuccess(self)
            finally:
                result.stopTest(self)

    def assertStringsEqual(self, a, b, msg=None, strip=False):
        """Assert equality between two strings and highlight any differences.
        If strip is true, leading and trailing whitespace is ignored.

        On mismatch a unified diff of the two strings is written to stderr
        before the test fails with *msg*.
        """
        if strip:
            a = a.strip()
            b = b.strip()

        if a != b:
            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
                             "a diff follows\n"
                             % ('when stripped ' if strip else '',
                                len(a), len(b),
                                ))

            from difflib import unified_diff
            diff = unified_diff(a.splitlines(True),
                                b.splitlines(True),
                                'a', 'b')
            for line in diff:
                sys.stderr.write(line)

            self.fail(msg)
class LdbTestCase(TestCase):
    """Trivial test case for running tests against a LDB.

    Each test gets a fresh LDB (``self.ldb``) backed by a temporary file
    (``self.filename``), which is closed and removed after the test.
    """

    def setUp(self):
        super(LdbTestCase, self).setUp()
        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
        self.filename = self.tempfile.name
        # delete=False means nothing removes the file for us; register
        # cleanups (run LIFO: close the handle first, then delete the file)
        # so each test does not leak an open handle and a file on disk.
        self.addCleanup(self._remove_ldb_file)
        self.addCleanup(self.tempfile.close)
        self.ldb = samba.Ldb(self.filename)

    def _remove_ldb_file(self):
        """Delete the temporary LDB file, tolerating it already being gone."""
        import os
        if os.path.exists(self.filename):
            os.unlink(self.filename)

    def set_modules(self, modules=()):
        """Change the modules for this Ldb.

        Writes the comma-joined *modules* into the @MODULES record and then
        reopens ``self.ldb`` on the same file.

        :param modules: Iterable of module names (default: none).
        """
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, "@MODULES")
        m["@LIST"] = ",".join(modules)
        self.ldb.add(m)
        self.ldb = samba.Ldb(self.filename)
class TestCaseInTempDir(TestCase):
    """Test case that gets a fresh temporary directory per test.

    The directory is available as ``self.tempdir``.  A cleanup asserts that
    the test left it empty, removes it and resets ``self.tempdir`` to None.
    """

    def setUp(self):
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        # assertEqual, not the deprecated alias assertEquals, which was
        # removed in Python 3.12.
        self.assertEqual([], os.listdir(self.tempdir))
        os.rmdir(self.tempdir)
        self.tempdir = None
def env_loadparm():
    """Return a LoadParm loaded from the file named by $SMB_CONF_PATH.

    :raise KeyError: if SMB_CONF_PATH is not set in the environment.
    """
    loadparm = param.LoadParm()
    try:
        smb_conf_path = os.environ["SMB_CONF_PATH"]
        loadparm.load(smb_conf_path)
    except KeyError:
        raise KeyError("SMB_CONF_PATH not set")
    return loadparm
def env_get_var_value(var_name, allow_missing=False):
    """Returns value for variable in os.environ

    Function throws AssertionError if variable is not defined (unless
    allow_missing is set, in which case None is returned instead).
    Unit-test based python tests require certain input params
    to be set in environment, otherwise they can't be run

    :param var_name: Name of the environment variable to read.
    :param allow_missing: If true, return None for an unset variable
        instead of failing.
    :return: The variable's value, or None (only when allow_missing).
    :raise AssertionError: if the variable is unset and allow_missing
        is false.
    """
    value = os.environ.get(var_name)
    if value is None:
        if allow_missing:
            return None
        # Raise explicitly rather than via ``assert`` so the check (and its
        # helpful message) survives running Python with -O.
        raise AssertionError("Please supply %s in environment" % var_name)
    return value
# Module-wide credentials: returned by TestCase.get_credentials() and used by
# connect_samdb() when no credentials are passed.  None here; presumably
# assigned by the test runner before tests execute (TODO confirm).
cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """Base class for DCE/RPC test cases.

    It adds no behaviour of its own on top of TestCase.
    """
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        candidate = "FOO"
        self.assertTrue(samba.valid_netbios_name(candidate))

    def test_too_long(self):
        # 30 characters; expected to be rejected for its length.
        candidate = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(candidate))

    def test_invalid_characters(self):
        candidate = "*BLA"
        self.assertFalse(samba.valid_netbios_name(candidate))
class BlackboxProcessError(Exception):
    """This is raised when a blackbox command exits with an unexpected status

    BlackboxTestCase raises it from check_exit_code()/check_run() when the
    exit status differs from the expected one, and from check_output() when
    it is non-zero.

    Exception instance should contain the exact exit code (S.returncode),
    command line (S.cmd), process output (S.stdout) and process error stream
    (S.stderr), plus an optional caller-supplied message (S.msg).
    """

    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
        # Pass every constructor argument on to Exception so that .args is
        # populated; Exception.__reduce__ rebuilds instances from .args, so
        # without this the exception cannot be unpickled.
        super(BlackboxProcessError, self).__init__(returncode, cmd, stdout,
                                                   stderr, msg)
        self.returncode = returncode
        self.cmd = cmd
        self.stdout = stdout
        self.stderr = stderr
        self.msg = msg

    def __str__(self):
        s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
             (self.cmd, self.returncode, self.stdout, self.stderr))
        if self.msg is not None:
            s = "%s; message: %s" % (s, self.msg)

        return s
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests."""

    def _make_cmdline(self, line):
        """Resolve the command word of *line* against the build's bin dir.

        The bin directory is "../../../../bin" relative to this file.  If the
        first space-separated word of *line* names an existing path inside
        it, that word is replaced by the full path; the rest of *line* is
        left untouched.
        """
        bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
        parts = line.split(" ")
        # An empty first word (empty line or leading space) would join to
        # bindir itself, which exists; never substitute that.
        if parts[0] and os.path.exists(os.path.join(bindir, parts[0])):
            parts[0] = os.path.join(bindir, parts[0])
        return " ".join(parts)

    def _run_cmdline(self, line):
        """Run *line* through the shell and collect its result.

        Shared by check_exit_code() and check_output() so both spawn the
        child the same way (previously only check_output() closed fds).

        :return: (returncode, stdoutdata, stderrdata)
        """
        # NOTE: shell=True -- *line* is interpreted by the shell, so it must
        # only ever be built by test code, never from untrusted input.
        p = subprocess.Popen(line,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             shell=True,
                             close_fds=True)
        stdoutdata, stderrdata = p.communicate()
        return p.returncode, stdoutdata, stderrdata

    def check_run(self, line, msg=None):
        """Run *line* and require a zero exit status.

        :raise BlackboxProcessError: if the command exits non-zero.
        """
        self.check_exit_code(line, 0, msg=msg)

    def check_exit_code(self, line, expected, msg=None):
        """Run *line* and require that it exits with status *expected*.

        :raise BlackboxProcessError: (carrying *msg*) on any other status.
        """
        line = self._make_cmdline(line)
        retcode, stdoutdata, stderrdata = self._run_cmdline(line)
        if retcode != expected:
            raise BlackboxProcessError(retcode,
                                       line,
                                       stdoutdata,
                                       stderrdata,
                                       msg)

    def check_output(self, line):
        """Run *line*, require a zero exit status and return its stdout.

        :raise BlackboxProcessError: if the command exits non-zero.
        """
        line = self._make_cmdline(line)
        retcode, stdoutdata, stderrdata = self._run_cmdline(line)
        if retcode:
            raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
        return stdoutdata
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
                  flags=0, ldb_options=None, ldap_only=False, global_schema=True):
    """Create SamDB instance and connects to samdb_url database.

    :param samdb_url: Url for database to connect to.  A value without a
        "scheme://" prefix is treated as a local tdb file when such a file
        exists (and ldap_only is not set), otherwise as an LDAP host.
    :param lp: Optional loadparm object; defaults to env_loadparm().
    :param session_info: Optional session information; defaults to
        samba.auth.system_session(lp).
    :param credentials: Optional credentials; defaults to the module-level
        cmdline_credentials.
    :param flags: Optional LDB flags
    :param ldb_options: Optional list of LDB options.  For ldap:// URLs
        "modules:paged_searches" is appended to (a copy of) it.
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema.
    :raise AssertionError: if ldap_only is set but the URL is not ldap://.

    Added value for tests is that we have a shorthand function
    to make proper URL for ldb.connect() while using default
    parameters for connection based on test environment
    """
    if "://" not in samdb_url:
        if not ldap_only and os.path.isfile(samdb_url):
            samdb_url = "tdb://%s" % samdb_url
        else:
            samdb_url = "ldap://%s" % samdb_url
    # use 'paged_search' module when connecting remotely
    if samdb_url.startswith("ldap://"):
        # Add to the caller's options instead of replacing them (they used
        # to be silently discarded); copy so the caller's list is untouched.
        ldb_options = list(ldb_options or [])
        if "modules:paged_searches" not in ldb_options:
            ldb_options.append("modules:paged_searches")
    elif ldap_only:
        raise AssertionError("Trying to connect to %s while remote "
                             "connection is required" % samdb_url)

    # set defaults for test environment
    if lp is None:
        lp = env_loadparm()
    if session_info is None:
        session_info = samba.auth.system_session(lp)
    if credentials is None:
        credentials = cmdline_credentials

    return SamDB(url=samdb_url,
                 lp=lp,
                 session_info=session_info,
                 credentials=credentials,
                 flags=flags,
                 options=ldb_options,
                 global_schema=global_schema)
def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
                     flags=0, ldb_options=None, ldap_only=False,
                     global_schema=True):
    """Connects to samdb_url database and fetches its rootDSE record

    :param samdb_url: Url for database to connect to.
    :param lp: Optional loadparm object
    :param session_info: Optional session information
    :param credentials: Optional credentials; connect_samdb() defaults them
        to the module-level cmdline_credentials.
    :param flags: Optional LDB flags
    :param ldb_options: Optional list of LDB options
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema (forwarded to
        connect_samdb(), same default).
    :return: (sam_db_connection, rootDse_record) tuple
    """
    sam_db = connect_samdb(samdb_url, lp=lp, session_info=session_info,
                           credentials=credentials, flags=flags,
                           ldb_options=ldb_options, ldap_only=ldap_only,
                           global_schema=global_schema)
    # fetch RootDse: base-scope search on the empty DN for all attributes
    res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
                        attrs=["*"])
    return (sam_db, res[0])
def connect_samdb_env(env_url, env_username, env_password, lp=None):
    """Connect to SamDB by getting URL and Credentials from environment

    :param env_url: Environment variable name to get ldb url from
    :param env_username: Username environment variable
    :param env_password: Password environment variable
    :param lp: Optional loadparm object; a fresh param.LoadParm() is used
        when omitted.
    :return: sam_db_connection
    :raise AssertionError: if any of the environment variables is unset
        (raised by env_get_var_value()).
    """
    samdb_url = env_get_var_value(env_url)
    creds = credentials.Credentials()
    if lp is None:
        lp = param.LoadParm()
    # guess Credentials parameters here, whichever lp is in use (this used
    # to be skipped when the caller passed lp). Otherwise workstation
    # and domain fields are NULL and gencache code segfaults
    creds.guess(lp)
    creds.set_username(env_get_var_value(env_username))
    creds.set_password(env_get_var_value(env_password))
    return connect_samdb(samdb_url, credentials=creds, lp=lp)
def delete_force(samdb, dn, **kwargs):
    """Delete *dn* from *samdb*, treating "no such object" as success.

    Extra keyword arguments are passed through to samdb.delete().

    :raise AssertionError: if the delete fails with any other LDB error.
    """
    try:
        samdb.delete(dn, **kwargs)
    except ldb.LdbError as error:
        (num, errstr) = error.args
        # Explicit raise rather than ``assert``: under python -O the assert
        # was stripped and every delete failure was silently swallowed.
        if num != ldb.ERR_NO_SUCH_OBJECT:
            raise AssertionError("ldb.delete() failed: %s" % errstr)
493 def create_test_ou(samdb, name):
494 """Creates a unique OU for the test"""
496 # Add some randomness to the test OU. Replication between the testenvs is
497 # constantly happening in the background. Deletion of the last test's
498 # objects can be slow to replicate out. So the OU created by a previous
499 # testenv may still exist at the point that tests start on another testenv.
500 rand = randint(1, 10000000)
501 dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
502 samdb.add({"dn": dn, "objectclass": "organizationalUnit"})