1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from samba.compat import string_types
40 from random import randint
try:
    from samba.samdb import SamDB
except ImportError:
    # Samba was built without SamDB support: provide a stand-in under the
    # same name so that connect_samdb() can still be called and recover.
    def SamDB(*args, **kwargs):
        """Fallback SamDB constructor; accepts anything and yields None."""
        return None
50 import samba.dcerpc.dcerpc
51 import samba.dcerpc.epmapper
try:
    from unittest import SkipTest
except ImportError:
    # Very old unittest versions lack SkipTest; define a local equivalent.
    class SkipTest(Exception):
        """Raised to mark a test as skipped."""
def _hexdump_display_byte(code):
    """Map byte value *code* to itself if plainly printable ASCII, else '.'."""
    # repr() of a single ordinary printable character is exactly three
    # characters long (the character between two quotes); control bytes
    # and the backslash need escaping and therefore produce a longer repr.
    if code < 127 and len(repr(chr(code))) == 3:
        return code
    return ord('.')


# 256-entry translation table used by TestCase.hexdump() for the text column.
HEXDUMP_FILTER = bytearray(_hexdump_display_byte(code) for code in range(256))
class TestCase(unittest.TestCase):
    """A Samba test case."""

    def setUp(self):
        super(TestCase, self).setUp()
        # TEST_DEBUG_LEVEL, if set, overrides the samba debug level while
        # the test runs.
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Restore the level saved above once the test is done; the
            # cleanup used to re-apply test_debug_level, so the saved value
            # was never used and the override leaked past the test.
            self.addCleanup(samba.set_debug_level, self._old_debug_level)

    def get_loadparm(self):
        """Return the LoadParm for the test environment (env_loadparm())."""
        return env_loadparm()

    def get_credentials(self):
        """Return the module-level cmdline_credentials."""
        return cmdline_credentials

    def get_creds_ccache_name(self):
        """Return the name of the named ccache of get_credentials()."""
        creds = self.get_credentials()
        ccache = creds.get_named_ccache(self.get_loadparm())
        ccache_name = ccache.get_name()

        return ccache_name

    def hexdump(self, src):
        """Return a hex dump of src as a multi-line string.

        Each output line covers 16 elements of src: the offset, the hex
        values of the first and of the second group of 8, then the same
        data passed through HEXDUMP_FILTER (unprintable bytes become '.').
        src may be a text string (string_types) or a bytes-like object.
        """
        N = 0
        result = []
        is_string = isinstance(src, string_types)
        while src:
            ll = src[:8]
            lr = src[8:16]
            src = src[16:]
            if is_string:
                hl = ' '.join(["%02X" % ord(x) for x in ll])
                hr = ' '.join(["%02X" % ord(x) for x in lr])
                ll = ll.translate(HEXDUMP_FILTER)
                lr = lr.translate(HEXDUMP_FILTER)
            else:
                hl = ' '.join(["%02X" % x for x in ll])
                hr = ' '.join(["%02X" % x for x in lr])
                ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
                lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
            # Collect lines and join once instead of repeated string +=.
            result.append("[%04X] %-*s %-*s %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr))
            N += 16
        return ''.join(result)

    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
        """Build a new Credentials object based on template.

        :param template: Credentials to take domain, realm, workstation and
            (by default) username/password/kerberos state from; required.
        :param username: Optional username; must be given together with
            userpass (or both omitted, to use the template's values).
        :param userpass: Optional password matching username.
        :param kerberos_state: Optional kerberos state; defaults to the
            template's.
        :return: the new Credentials, with gensec.FEATURE_SEAL enabled.
        """

        if template is None:
            assert template is not None

        if username is not None:
            assert userpass is not None

        if username is None:
            assert userpass is None

            username = template.get_username()
            userpass = template.get_password()

        if kerberos_state is None:
            kerberos_state = template.get_kerberos_state()

        # Build a fresh Credentials object populated from the template and
        # the overrides above (the template itself is not modified).
        c = Credentials()
        c.set_username(username)
        c.set_password(userpass)
        c.set_domain(template.get_domain())
        c.set_realm(template.get_realm())
        c.set_workstation(template.get_workstation())
        c.set_gensec_features(c.get_gensec_features()
                              | gensec.FEATURE_SEAL)
        c.set_kerberos_state(kerberos_state)
        return c

    # These functions didn't exist before Python2.7:
    if sys.version_info < (2, 7):
        import warnings

        def skipTest(self, reason):
            raise SkipTest(reason)

        def assertIn(self, member, container, msg=None):
            self.assertTrue(member in container, msg)

        def assertIs(self, a, b, msg=None):
            self.assertTrue(a is b, msg)

        def assertIsNot(self, a, b, msg=None):
            self.assertTrue(a is not b, msg)

        def assertIsNotNone(self, a, msg=None):
            # Forward msg like the sibling shims do (it was dropped before).
            self.assertTrue(a is not None, msg)

        def assertIsInstance(self, a, b, msg=None):
            self.assertTrue(isinstance(a, b), msg)

        def assertIsNone(self, a, msg=None):
            self.assertTrue(a is None, msg)

        def assertGreater(self, a, b, msg=None):
            self.assertTrue(a > b, msg)

        def assertGreaterEqual(self, a, b, msg=None):
            self.assertTrue(a >= b, msg)

        def assertLess(self, a, b, msg=None):
            self.assertTrue(a < b, msg)

        def assertLessEqual(self, a, b, msg=None):
            self.assertTrue(a <= b, msg)

        def addCleanup(self, fn, *args, **kwargs):
            # Cleanups are run in reverse order of registration by run().
            self._cleanups = getattr(self, "_cleanups", []) + [
                (fn, args, kwargs)]

        def assertRegexpMatches(self, text, regex, msg=None):
            # PY3 note: Python 3 will never see this, but we use
            # text_type for the benefit of linters.
            if isinstance(regex, (str, text_type)):
                regex = re.compile(regex)
            if not regex.search(text):
                self.fail(msg)

        def _addSkip(self, result, reason):
            # Report a skip, falling back to a success (with a warning) for
            # result objects that predate addSkip().
            addSkip = getattr(result, 'addSkip', None)
            if addSkip is not None:
                addSkip(self, reason)
            else:
                warnings.warn("TestResult has no addSkip method, skips not reported",
                              RuntimeWarning, 2)
                result.addSuccess(self)

        def run(self, result=None):
            # Copy of unittest's run() extended with SkipTest handling and
            # the addCleanup() shim above.
            if result is None: result = self.defaultTestResult()
            result.startTest(self)
            testMethod = getattr(self, self._testMethodName)
            try:
                try:
                    self.setUp()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    return

                ok = False
                try:
                    testMethod()
                    ok = True
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except self.failureException:
                    result.addFailure(self, self._exc_info())
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())

                try:
                    self.tearDown()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    ok = False

                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
                    fn(*args, **kwargs)
                if ok: result.addSuccess(self)
            finally:
                result.stopTest(self)

    def assertStringsEqual(self, a, b, msg=None, strip=False):
        """Assert equality between two strings and highlight any differences.
        If strip is true, leading and trailing whitespace is ignored."""
        if strip:
            a = a.strip()
            b = b.strip()

        if a != b:
            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
                             "a diff follows\n"
                             % ('when stripped ' if strip else '',
                                len(a), len(b),
                                ))

            from difflib import unified_diff
            diff = unified_diff(a.splitlines(True),
                                b.splitlines(True),
                                'a', 'b')
            for line in diff:
                sys.stderr.write(line)

            self.fail(msg)
class LdbTestCase(TestCase):
    """Trivial test case for running tests against a LDB."""

    def setUp(self):
        super(LdbTestCase, self).setUp()
        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
        self.filename = self.tempfile.name
        # delete=False keeps the file on disk so samba.Ldb can open it by
        # name; register our own cleanup so it does not leak per test.
        self.addCleanup(self._remove_ldb_tempfile, self.tempfile)
        self.ldb = samba.Ldb(self.filename)

    @staticmethod
    def _remove_ldb_tempfile(tmp):
        """Close tmp and delete its file; an already-removed file is fine."""
        tmp.close()
        try:
            os.unlink(tmp.name)
        except OSError:
            pass

    def set_modules(self, modules=()):
        """Change the modules for this Ldb.

        :param modules: iterable of module names, written comma-separated
            to the @LIST attribute of the @MODULES record.
        """
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, "@MODULES")
        m["@LIST"] = ",".join(modules)
        self.ldb.add(m)
        # Reopen the database, presumably so the new module list is loaded.
        self.ldb = samba.Ldb(self.filename)
class TestCaseInTempDir(TestCase):
    """Test case that provides a fresh temporary directory as self.tempdir.

    The test must leave the directory empty: the cleanup asserts this
    before removing it.
    """

    def setUp(self):
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual([], os.listdir(self.tempdir))
        os.rmdir(self.tempdir)
        self.tempdir = None
def env_loadparm():
    """Return a LoadParm loaded from the file named by $SMB_CONF_PATH.

    :raise KeyError: if SMB_CONF_PATH is not set in the environment.
    """
    lp = param.LoadParm()
    # Only the environment lookup is guarded, so a KeyError raised while
    # loading the configuration is not misreported as a missing variable.
    try:
        smb_conf_path = os.environ["SMB_CONF_PATH"]
    except KeyError:
        raise KeyError("SMB_CONF_PATH not set")
    lp.load(smb_conf_path)
    return lp
def env_get_var_value(var_name, allow_missing=False):
    """Returns value for variable in os.environ

    Function throws AssertionError if variable is not defined, unless
    allow_missing is true, in which case None is returned instead.
    Unit-test based python tests require certain input params
    to be set in environment, otherwise they can't be run
    """
    if var_name not in os.environ:
        if allow_missing:
            return None
        # Raised explicitly (not via ``assert``) so the check also happens
        # under ``python -O``; the exception type is kept for callers.
        raise AssertionError("Please supply %s in environment" % var_name)
    return os.environ[var_name]
# Default credentials for tests: returned by TestCase.get_credentials() and
# used by connect_samdb() when no credentials are passed.  None here;
# presumably assigned by the test runner before tests execute — confirm.
cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """DCE/RPC Test case.

    Adds nothing to TestCase itself; presumably used as the common base
    class for DCE/RPC interface tests.
    """
class ValidNetbiosNameTests(TestCase):
    """Checks of samba.valid_netbios_name() on accepted and rejected names."""

    def test_valid(self):
        # A short, plain upper-case name is accepted.
        candidate = "FOO"
        self.assertTrue(samba.valid_netbios_name(candidate))

    def test_too_long(self):
        # A 30-character name is rejected as too long.
        candidate = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(candidate))

    def test_invalid_characters(self):
        # A name containing '*' is rejected.
        candidate = "*BLA"
        self.assertFalse(samba.valid_netbios_name(candidate))
class BlackboxProcessError(Exception):
    """Raised when a blackbox command exits with an unexpected status.

    Exception instance should contain the exact exit code (S.returncode),
    command line (S.cmd), process output (S.stdout), process error stream
    (S.stderr) and an optional extra message (S.msg).
    """

    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
        # Hand the constructor arguments to Exception so ``self.args`` is
        # populated: this keeps repr() informative and lets the exception be
        # pickled and re-created (Exception.__reduce__ relies on args).
        super(BlackboxProcessError, self).__init__(returncode, cmd, stdout,
                                                   stderr, msg)
        self.returncode = returncode
        self.cmd = cmd
        self.stdout = stdout
        self.stderr = stderr
        self.msg = msg

    def __str__(self):
        s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
             (self.cmd, self.returncode, self.stdout, self.stderr))
        if self.msg is not None:
            s = "%s; message: %s" % (s, self.msg)

        return s
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests."""

    def _make_cmdline(self, line):
        # If the first word of the command names a binary in the build's
        # bin/ directory, substitute its absolute path; the rest of the
        # command line is kept verbatim.
        bindir = os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              "../../../../bin"))
        program, sep, rest = line.partition(" ")
        candidate = os.path.join(bindir, program)
        if os.path.exists(candidate):
            program = candidate
        return program + sep + rest

    def check_run(self, line, msg=None):
        """Run line through the shell and require exit status 0."""
        self.check_exit_code(line, 0, msg=msg)

    def check_exit_code(self, line, expected, msg=None):
        """Run line through the shell; raise BlackboxProcessError unless it
        exits with status expected."""
        cmdline = self._make_cmdline(line)
        proc = subprocess.Popen(cmdline,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        out, err = proc.communicate()
        status = proc.returncode
        if status == expected:
            return
        raise BlackboxProcessError(status, cmdline, out, err, msg)

    def check_output(self, line):
        """Run line through the shell and return its stdout; a non-zero
        exit status raises BlackboxProcessError."""
        cmdline = self._make_cmdline(line)
        proc = subprocess.Popen(cmdline,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True,
                                close_fds=True)
        out, err = proc.communicate()
        status = proc.returncode
        if not status:
            return out
        raise BlackboxProcessError(status, cmdline, out, err)
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
                  flags=0, ldb_options=None, ldap_only=False, global_schema=True):
    """Create SamDB instance and connects to samdb_url database.

    :param samdb_url: Url for database to connect to.  A value without a
        "scheme://" prefix becomes "tdb://..." if it names an existing file
        (and ldap_only is not set), otherwise "ldap://...".
    :param lp: Optional loadparm object; defaults to env_loadparm().
    :param session_info: Optional session information; defaults to
        samba.auth.system_session(lp).
    :param credentials: Optional credentials; defaults to the module-level
        cmdline_credentials.
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options.  For ldap:// URLs this is
        replaced by ["modules:paged_searches"].
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema.
    :raise AssertionError: if ldap_only is set but the URL is not ldap://.

    Added value for tests is that we have a shorthand function
    to make proper URL for ldb.connect() while using default
    parameters for connection based on test environment
    """
    if "://" not in samdb_url:
        if not ldap_only and os.path.isfile(samdb_url):
            samdb_url = "tdb://%s" % samdb_url
        else:
            samdb_url = "ldap://%s" % samdb_url
    # use 'paged_search' module when connecting remotely
    if samdb_url.startswith("ldap://"):
        ldb_options = ["modules:paged_searches"]
    elif ldap_only:
        raise AssertionError("Trying to connect to %s while remote "
                             "connection is required" % samdb_url)

    # set defaults for test environment
    if lp is None:
        lp = env_loadparm()
    if session_info is None:
        session_info = samba.auth.system_session(lp)
    if credentials is None:
        credentials = cmdline_credentials

    return SamDB(url=samdb_url,
                 lp=lp,
                 session_info=session_info,
                 credentials=credentials,
                 flags=flags,
                 options=ldb_options,
                 global_schema=global_schema)
def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
                     flags=0, ldb_options=None, ldap_only=False):
    """Connect to samdb_url (via connect_samdb) and fetch its rootDSE.

    :param samdb_url: Url for database to connect to.
    :param lp: Optional loadparm object
    :param session_info: Optional session information
    :param credentials: Optional credentials (see connect_samdb for default).
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options, passed to connect_samdb.
    :param ldap_only: If set, only remote LDAP connection will be created.
    :return: (sam_db_connection, rootDse_record) tuple
    """
    samdb = connect_samdb(samdb_url, lp=lp, session_info=session_info,
                          credentials=credentials, flags=flags,
                          ldb_options=ldb_options, ldap_only=ldap_only)
    # Base-scope search of the empty DN yields the rootDSE record.
    rootdse = samdb.search(base="", expression="", scope=ldb.SCOPE_BASE,
                           attrs=["*"])[0]
    return (samdb, rootdse)
def connect_samdb_env(env_url, env_username, env_password, lp=None):
    """Connect to SamDB by getting URL and Credentials from environment

    :param env_url: Environment variable name to get ldb url from
    :param env_username: Username environment variable
    :param env_password: Password environment variable
    :param lp: Optional loadparm object; a fresh LoadParm() if not given.
    :return: sam_db_connection
    """
    samdb_url = env_get_var_value(env_url)
    creds = credentials.Credentials()
    if lp is None:
        lp = param.LoadParm()
    # guess Credentials parameters here, also for a caller-supplied lp.
    # Otherwise workstation and domain fields are NULL and gencache code
    # segfaults.
    creds.guess(lp)
    creds.set_username(env_get_var_value(env_username))
    creds.set_password(env_get_var_value(env_password))
    return connect_samdb(samdb_url, credentials=creds, lp=lp)
def delete_force(samdb, dn, **kwargs):
    """Delete dn from samdb, treating "no such object" as success.

    :param samdb: connection whose delete() is called
    :param dn: DN of the object to delete
    :param kwargs: passed through to samdb.delete()
    :raise AssertionError: if the delete fails with any other LDB error
    """
    try:
        samdb.delete(dn, **kwargs)
    except ldb.LdbError as error:
        (num, errstr) = error.args
        # Raised explicitly instead of via ``assert`` so unexpected errors
        # are not silently swallowed when running under ``python -O``.
        if num != ldb.ERR_NO_SUCH_OBJECT:
            raise AssertionError("ldb.delete() failed: %s" % errstr)
495 def create_test_ou(samdb, name):
496 """Creates a unique OU for the test"""
498 # Add some randomness to the test OU. Replication between the testenvs is
499 # constantly happening in the background. Deletion of the last test's
500 # objects can be slow to replicate out. So the OU created by a previous
501 # testenv may still exist at the point that tests start on another testenv.
502 rand = randint(1, 10000000)
503 dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
504 samdb.add({"dn": dn, "objectclass": "organizationalUnit"})