1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 from samba import gensec
36 import samba.dcerpc.base
37 from samba.compat import PY3, text_type
40 from samba.samdb import SamDB
42 import samba.dcerpc.dcerpc
43 import samba.dcerpc.epmapper
46 from unittest import SkipTest
48 class SkipTest(Exception):
# Translation table used by hexdump(): a character whose repr() is just the
# quoted character itself (three characters long) maps to itself, every
# other character maps to '.'.
HEXDUMP_FILTER = ''.join(
    [chr(x) if len(repr(chr(x))) == 3 else '.' for x in range(256)])
53 class TestCase(unittest.TestCase):
54 """A Samba test case."""
57 super(TestCase, self).setUp()
58 test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
59 if test_debug_level is not None:
60 test_debug_level = int(test_debug_level)
61 self._old_debug_level = samba.get_debug_level()
62 samba.set_debug_level(test_debug_level)
63 self.addCleanup(samba.set_debug_level, test_debug_level)
65 def get_loadparm(self):
def get_credentials(self):
    """Return the module-level ``cmdline_credentials`` object unchanged."""
    shared_creds = cmdline_credentials
    return shared_creds
71 def get_creds_ccache_name(self):
72 creds = self.get_credentials()
73 ccache = creds.get_named_ccache(self.get_loadparm())
74 ccache_name = ccache.get_name()
78 def hexdump(self, src):
85 hl = ' '.join(["%02X" % ord(x) for x in ll])
86 hr = ' '.join(["%02X" % ord(x) for x in lr])
87 ll = ll.translate(HEXDUMP_FILTER)
88 lr = lr.translate(HEXDUMP_FILTER)
89 result += "[%04X] %-*s %-*s %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
93 def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
96 assert template is not None
98 if username is not None:
99 assert userpass is not None
102 assert userpass is None
104 username = template.get_username()
105 userpass = template.get_password()
107 if kerberos_state is None:
108 kerberos_state = template.get_kerberos_state()
110 # get a copy of the global creds or the passed in creds
112 c.set_username(username)
113 c.set_password(userpass)
114 c.set_domain(template.get_domain())
115 c.set_realm(template.get_realm())
116 c.set_workstation(template.get_workstation())
117 c.set_gensec_features(c.get_gensec_features()
118 | gensec.FEATURE_SEAL)
119 c.set_kerberos_state(kerberos_state)
124 # These functions didn't exist before Python2.7:
125 if sys.version_info < (2, 7):
def skipTest(self, reason):
    """Abort the running test by raising SkipTest carrying *reason*."""
    skip_signal = SkipTest(reason)
    raise skip_signal
def assertIn(self, member, container, msg=None):
    """Fail with *msg* unless *member* is found in *container*."""
    is_present = member in container
    self.assertTrue(is_present, msg)
def assertIs(self, a, b, msg=None):
    """Fail with *msg* unless *a* and *b* are the very same object."""
    same_object = a is b
    self.assertTrue(same_object, msg)
def assertIsNot(self, a, b, msg=None):
    """Fail with *msg* if *a* and *b* are the very same object."""
    distinct_objects = a is not b
    self.assertTrue(distinct_objects, msg)
def assertIsNotNone(self, a, msg=None):
    """Fail unless *a* is something other than None.

    :param a: Value under test.
    :param msg: Optional message reported on failure.
    """
    # Forward msg like the neighbouring shims do; it used to be dropped, so
    # a caller's explanation never reached the failure report.
    self.assertTrue(a is not None, msg)
def assertIsInstance(self, a, b, msg=None):
    """Fail with *msg* unless ``isinstance(a, b)`` holds."""
    type_matches = isinstance(a, b)
    self.assertTrue(type_matches, msg)
def assertIsNone(self, a, msg=None):
    """Fail with *msg* unless *a* is None."""
    is_none = a is None
    self.assertTrue(is_none, msg)
def assertGreater(self, a, b, msg=None):
    """Fail with *msg* unless ``a > b``."""
    strictly_above = a > b
    self.assertTrue(strictly_above, msg)
def assertGreaterEqual(self, a, b, msg=None):
    """Fail with *msg* unless ``a >= b``."""
    at_least = a >= b
    self.assertTrue(at_least, msg)
def assertLess(self, a, b, msg=None):
    """Fail with *msg* unless ``a < b``."""
    strictly_below = a < b
    self.assertTrue(strictly_below, msg)
def assertLessEqual(self, a, b, msg=None):
    """Fail with *msg* unless ``a <= b``."""
    at_most = a <= b
    self.assertTrue(at_most, msg)
161 def addCleanup(self, fn, *args, **kwargs):
162 self._cleanups = getattr(self, "_cleanups", []) + [
165 def assertRegexpMatches(self, text, regex, msg=None):
166 # PY3 note: Python 3 will never see this, but we use
167 # text_type for the benefit of linters.
168 if isinstance(regex, (str, text_type)):
169 regex = re.compile(regex)
170 if not regex.search(text):
173 def _addSkip(self, result, reason):
174 addSkip = getattr(result, 'addSkip', None)
175 if addSkip is not None:
176 addSkip(self, reason)
178 warnings.warn("TestResult has no addSkip method, skips not reported",
180 result.addSuccess(self)
182 def run(self, result=None):
183 if result is None: result = self.defaultTestResult()
184 result.startTest(self)
185 testMethod = getattr(self, self._testMethodName)
189 except SkipTest as e:
190 self._addSkip(result, str(e))
192 except KeyboardInterrupt:
195 result.addError(self, self._exc_info())
202 except SkipTest as e:
203 self._addSkip(result, str(e))
205 except self.failureException:
206 result.addFailure(self, self._exc_info())
207 except KeyboardInterrupt:
210 result.addError(self, self._exc_info())
214 except SkipTest as e:
215 self._addSkip(result, str(e))
216 except KeyboardInterrupt:
219 result.addError(self, self._exc_info())
222 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
224 if ok: result.addSuccess(self)
226 result.stopTest(self)
228 def assertStringsEqual(self, a, b, msg=None, strip=False):
229 """Assert equality between two strings and highlight any differences.
230 If strip is true, leading and trailing whitespace is ignored."""
236 sys.stderr.write("The strings differ %s(lengths %d vs %d); "
238 % ('when stripped ' if strip else '',
242 from difflib import unified_diff
243 diff = unified_diff(a.splitlines(True),
247 sys.stderr.write(line)
252 class LdbTestCase(TestCase):
253 """Trivial test case for running tests against a LDB."""
256 super(LdbTestCase, self).setUp()
257 self.filename = os.tempnam()
258 self.ldb = samba.Ldb(self.filename)
260 def set_modules(self, modules=[]):
261 """Change the modules for this Ldb."""
263 m.dn = ldb.Dn(self.ldb, "@MODULES")
264 m["@LIST"] = ",".join(modules)
266 self.ldb = samba.Ldb(self.filename)
269 class TestCaseInTempDir(TestCase):
272 super(TestCaseInTempDir, self).setUp()
273 self.tempdir = tempfile.mkdtemp()
274 self.addCleanup(self._remove_tempdir)
276 def _remove_tempdir(self):
277 self.assertEquals([], os.listdir(self.tempdir))
278 os.rmdir(self.tempdir)
283 lp = param.LoadParm()
285 lp.load(os.environ["SMB_CONF_PATH"])
287 raise KeyError("SMB_CONF_PATH not set")
291 def env_get_var_value(var_name, allow_missing=False):
292 """Returns value for variable in os.environ
294 Function throws AssertionError if variable is not defined.
295 Unit-test based python tests require certain input params
296 to be set in environment, otherwise they can't be run
299 if var_name not in os.environ.keys():
301 assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
302 return os.environ[var_name]
305 cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """Common base for test cases that exercise DCE/RPC interfaces."""
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        # A short, purely alphabetic name is expected to be accepted.
        accepted = samba.valid_netbios_name("FOO")
        self.assertTrue(accepted)

    def test_too_long(self):
        # "FOO" repeated ten times (30 characters) is expected to be refused.
        overlong_name = "FOO" * 10
        accepted = samba.valid_netbios_name(overlong_name)
        self.assertFalse(accepted)

    def test_invalid_characters(self):
        # A name starting with '*' is expected to be refused.
        accepted = samba.valid_netbios_name("*BLA")
        self.assertFalse(accepted)
323 class BlackboxProcessError(Exception):
324 """This is raised when check_output() process returns a non-zero exit status
326 Exception instance should contain the exact exit code (S.returncode),
327 command line (S.cmd), process output (S.stdout) and process error stream
331 def __init__(self, returncode, cmd, stdout, stderr):
332 self.returncode = returncode
338 return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
339 self.stdout, self.stderr)
341 class BlackboxTestCase(TestCaseInTempDir):
342 """Base test case for blackbox tests."""
344 def _make_cmdline(self, line):
345 bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
346 parts = line.split(" ")
347 if os.path.exists(os.path.join(bindir, parts[0])):
348 parts[0] = os.path.join(bindir, parts[0])
349 line = " ".join(parts)
def check_run(self, line):
    """Run the command *line* and require an exit status of 0.

    Delegates to check_exit_code() with an expected status of zero.
    """
    success_status = 0
    self.check_exit_code(line, success_status)
355 def check_exit_code(self, line, expected):
356 line = self._make_cmdline(line)
357 p = subprocess.Popen(line,
358 stdout=subprocess.PIPE,
359 stderr=subprocess.PIPE,
361 stdoutdata, stderrdata = p.communicate()
362 retcode = p.returncode
363 if retcode != expected:
364 raise BlackboxProcessError(retcode,
369 def check_output(self, line):
370 line = self._make_cmdline(line)
371 p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
372 stdoutdata, stderrdata = p.communicate()
373 retcode = p.returncode
375 raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
379 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
380 flags=0, ldb_options=None, ldap_only=False, global_schema=True):
381 """Create SamDB instance and connects to samdb_url database.
383 :param samdb_url: Url for database to connect to.
384 :param lp: Optional loadparm object
385 :param session_info: Optional session information
386 :param credentials: Optional credentials, defaults to cmdline_credentials.
387 :param flags: Optional LDB flags
388 :param ldap_only: If set, only remote LDAP connection will be created.
389 :param global_schema: Whether to use global schema.
391 Added value for tests is that we have a shorthand function
392 to make proper URL for ldb.connect() while using default
393 parameters for connection based on test environment
395 if not "://" in samdb_url:
396 if not ldap_only and os.path.isfile(samdb_url):
397 samdb_url = "tdb://%s" % samdb_url
399 samdb_url = "ldap://%s" % samdb_url
400 # use 'paged_searches' module when connecting remotely
401 if samdb_url.startswith("ldap://"):
402 ldb_options = ["modules:paged_searches"]
404 raise AssertionError("Trying to connect to %s while remote "
405 "connection is required" % samdb_url)
407 # set defaults for test environment
410 if session_info is None:
411 session_info = samba.auth.system_session(lp)
412 if credentials is None:
413 credentials = cmdline_credentials
415 return SamDB(url=samdb_url,
417 session_info=session_info,
418 credentials=credentials,
421 global_schema=global_schema)
424 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
425 flags=0, ldb_options=None, ldap_only=False):
426 """Connects to samdb_url database
428 :param samdb_url: Url for database to connect to.
429 :param lp: Optional loadparm object
430 :param session_info: Optional session information
431 :param credentials: Optional credentials, defaults to cmdline_credentials.
432 :param flags: Optional LDB flags
433 :param ldap_only: If set, only remote LDAP connection will be created.
434 :return: (sam_db_connection, rootDse_record) tuple
436 sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
437 flags, ldb_options, ldap_only)
439 res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
441 return (sam_db, res[0])
444 def connect_samdb_env(env_url, env_username, env_password, lp=None):
445 """Connect to SamDB by getting URL and Credentials from environment
447 :param env_url: Environment variable name to get ldb url from
448 :param env_username: Username environment variable
449 :param env_password: Password environment variable
450 :return: sam_db_connection
452 samdb_url = env_get_var_value(env_url)
453 creds = credentials.Credentials()
455 # guess Credentials parameters here. Otherwise workstation
456 # and domain fields are NULL and gencache code segfaults
457 lp = param.LoadParm()
459 creds.set_username(env_get_var_value(env_username))
460 creds.set_password(env_get_var_value(env_password))
461 return connect_samdb(samdb_url, credentials=creds, lp=lp)
464 def delete_force(samdb, dn, **kwargs):
466 samdb.delete(dn, **kwargs)
467 except ldb.LdbError as error:
468 (num, errstr) = error.args
469 assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr