1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 from samba import gensec
35 import samba.dcerpc.base
36 from samba.compat import PY3
39 from samba.samdb import SamDB
41 import samba.dcerpc.dcerpc
42 import samba.dcerpc.epmapper
try:
    from unittest import SkipTest
except ImportError:
    # unittest.SkipTest only exists from Python 2.7 on; provide a stand-in
    # that the Python < 2.7 TestCase.run() shim can catch.
    class SkipTest(Exception):
        """Test skipped."""
# 256-entry translation table used by TestCase.hexdump(): a character whose
# repr() is just itself in quotes (3 chars) maps to itself, anything else
# (control characters, escapes such as backslash) maps to '.'.
HEXDUMP_FILTER = ''.join(
    chr(code) if len(repr(chr(code))) == 3 else '.'
    for code in range(256)
)
class TestCase(unittest.TestCase):
    """A Samba test case."""

    def setUp(self):
        """Apply $TEST_DEBUG_LEVEL (if set) for the duration of the test."""
        super(TestCase, self).setUp()
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Restore the level that was active before the test. The cleanup
            # used to re-apply test_debug_level, so the override was never
            # undone.
            self.addCleanup(samba.set_debug_level, self._old_debug_level)
64 def get_loadparm(self):
67 def get_credentials(self):
68 return cmdline_credentials
70 def get_creds_ccache_name(self):
71 creds = self.get_credentials()
72 ccache = creds.get_named_ccache(self.get_loadparm())
73 ccache_name = ccache.get_name()
77 def hexdump(self, src):
84 hl = ' '.join(["%02X" % ord(x) for x in ll])
85 hr = ' '.join(["%02X" % ord(x) for x in lr])
86 ll = ll.translate(HEXDUMP_FILTER)
87 lr = lr.translate(HEXDUMP_FILTER)
88 result += "[%04X] %-*s %-*s %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
92 def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
95 assert template is not None
97 if username is not None:
98 assert userpass is not None
101 assert userpass is None
103 username = template.get_username()
104 userpass = template.get_password()
106 if kerberos_state is None:
107 kerberos_state = template.get_kerberos_state()
109 # get a copy of the global creds or a the passed in creds
111 c.set_username(username)
112 c.set_password(userpass)
113 c.set_domain(template.get_domain())
114 c.set_realm(template.get_realm())
115 c.set_workstation(template.get_workstation())
116 c.set_gensec_features(c.get_gensec_features()
117 | gensec.FEATURE_SEAL)
118 c.set_kerberos_state(kerberos_state)
123 # These functions didn't exist before Python2.7:
124 if sys.version_info < (2, 7):
127 def skipTest(self, reason):
128 raise SkipTest(reason)
130 def assertIn(self, member, container, msg=None):
131 self.assertTrue(member in container, msg)
133 def assertIs(self, a, b, msg=None):
134 self.assertTrue(a is b, msg)
136 def assertIsNot(self, a, b, msg=None):
137 self.assertTrue(a is not b, msg)
139 def assertIsNotNone(self, a, msg=None):
140 self.assertTrue(a is not None)
142 def assertIsInstance(self, a, b, msg=None):
143 self.assertTrue(isinstance(a, b), msg)
145 def assertIsNone(self, a, msg=None):
146 self.assertTrue(a is None, msg)
148 def assertGreater(self, a, b, msg=None):
149 self.assertTrue(a > b, msg)
151 def assertGreaterEqual(self, a, b, msg=None):
152 self.assertTrue(a >= b, msg)
154 def assertLess(self, a, b, msg=None):
155 self.assertTrue(a < b, msg)
157 def assertLessEqual(self, a, b, msg=None):
158 self.assertTrue(a <= b, msg)
# Python < 2.7 shim: addCleanup() appends a callback record to
# self._cleanups; run() below iterates that list in reverse as
# (fn, args, kwargs) tuples.
160 def addCleanup(self, fn, *args, **kwargs):
161 self._cleanups = getattr(self, "_cleanups", []) + [
# NOTE(review): the end of the statement above (original line 162) is
# missing from this copy of the file; restore it from upstream.
# Python < 2.7 shim: report a skip to the TestResult. When the result object
# has no addSkip method, a warning is emitted and the test is recorded via
# result.addSuccess() instead.
# NOTE(review): original lines 168 and 170 are missing here, presumably the
# `else:` branch and the rest of the warnings.warn() call; confirm.
164 def _addSkip(self, result, reason):
165 addSkip = getattr(result, 'addSkip', None)
166 if addSkip is not None:
167 addSkip(self, reason)
169 warnings.warn("TestResult has no addSkip method, skips not reported",
171 result.addSuccess(self)
# Python < 2.7 replacement for unittest.TestCase.run() that understands
# SkipTest. SkipTest is reported through _addSkip(). Test failures go to
# result.addFailure() and other errors to result.addError(). The registered
# cleanups run in reverse order, result.addSuccess() is called only when `ok`
# is set, and result.stopTest() is called last.
# NOTE(review): many lines of this method (original 177-179, 182, 184-185,
# 187-192, 195, 199-200, 202-204, 208-209, 211-212, 214, 216) are missing
# from this copy. These presumably include the setUp()/testMethod()/
# tearDown() calls, the KeyboardInterrupt re-raises and the try/finally
# around stopTest(). Restore from upstream before relying on it.
173 def run(self, result=None):
174 if result is None: result = self.defaultTestResult()
175 result.startTest(self)
176 testMethod = getattr(self, self._testMethodName)
180 except SkipTest as e:
181 self._addSkip(result, str(e))
183 except KeyboardInterrupt:
186 result.addError(self, self._exc_info())
193 except SkipTest as e:
194 self._addSkip(result, str(e))
196 except self.failureException:
197 result.addFailure(self, self._exc_info())
198 except KeyboardInterrupt:
201 result.addError(self, self._exc_info())
205 except SkipTest as e:
206 self._addSkip(result, str(e))
207 except KeyboardInterrupt:
210 result.addError(self, self._exc_info())
213 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
215 if ok: result.addSuccess(self)
217 result.stopTest(self)
# On a mismatch this writes a notice with both lengths, plus a unified diff of
# the two strings' lines (difflib.unified_diff), to stderr.
# NOTE(review): original lines 222-226, 228, 230-232, 235-237 and 239-240 are
# missing from this copy. They presumably hold the strip handling, the
# `if a != b:` test, the diff loop header and a final self.fail(msg); confirm
# against upstream.
219 def assertStringsEqual(self, a, b, msg=None, strip=False):
220 """Assert equality between two strings and highlight any differences.
221 If strip is true, leading and trailing whitespace is ignored."""
227 sys.stderr.write("The strings differ %s(lengths %d vs %d); "
229 % ('when stripped ' if strip else '',
233 from difflib import unified_diff
234 diff = unified_diff(a.splitlines(True),
238 sys.stderr.write(line)
class LdbTestCase(TestCase):
    """Trivial test case for running tests against a LDB."""

    def setUp(self):
        """Open a fresh LDB file in a private temporary directory."""
        import shutil

        super(LdbTestCase, self).setUp()
        # os.tempnam() does not exist on Python 3 and was race-prone (the
        # name could be taken before use). mkdtemp() creates a private
        # directory atomically; the database file lives inside it and the
        # whole directory is removed after the test.
        self._ldb_tempdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self._ldb_tempdir, ignore_errors=True)
        self.filename = os.path.join(self._ldb_tempdir, "test.ldb")
        self.ldb = samba.Ldb(self.filename)

    def set_modules(self, modules=None):
        """Change the modules for this Ldb.

        Adds an @MODULES record whose @LIST is the comma-separated *modules*,
        then reopens the database file.

        :param modules: iterable of module names; defaults to an empty list.
        """
        if modules is None:
            modules = []
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, "@MODULES")
        m["@LIST"] = ",".join(modules)
        self.ldb.add(m)
        # Reopen the database (presumably so the new module list is loaded).
        self.ldb = samba.Ldb(self.filename)
class TestCaseInTempDir(TestCase):
    """Test case that gets a fresh temporary directory in self.tempdir.

    The directory must be empty again when the test finishes; otherwise the
    cleanup fails the test.
    """

    def setUp(self):
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual([], os.listdir(self.tempdir))
        os.rmdir(self.tempdir)
def env_loadparm():
    """Return a LoadParm loaded from the file named by $SMB_CONF_PATH.

    :raises KeyError: if SMB_CONF_PATH is not set in the environment.
    """
    # Keep only the environment lookup inside the try, so that a KeyError
    # raised by lp.load() is not misreported as a missing variable.
    try:
        smb_conf_path = os.environ["SMB_CONF_PATH"]
    except KeyError:
        raise KeyError("SMB_CONF_PATH not set")
    lp = param.LoadParm()
    lp.load(smb_conf_path)
    return lp
def env_get_var_value(var_name, allow_missing=False):
    """Returns value for variable in os.environ

    Function throws AssertionError if variable is not defined, unless
    allow_missing is set, in which case None is returned instead.
    Unit-test based python tests require certain input params
    to be set in environment, otherwise they can't be run
    """
    if var_name not in os.environ:
        if allow_missing:
            return None
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise AssertionError("Please supply %s in environment" % var_name)
    return os.environ[var_name]
# Module-wide credentials: returned by TestCase.get_credentials() and used as
# the default by connect_samdb(). None until something assigns it
# (presumably the test runner; confirm).
cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """DCE/RPC Test case.

    Base class for tests that exercise a DCE/RPC interface.
    """
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        name = "FOO"
        self.assertTrue(samba.valid_netbios_name(name))

    def test_too_long(self):
        # 30 characters; expected to be rejected for its length.
        name = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(name))

    def test_invalid_characters(self):
        name = "*BLA"
        self.assertFalse(samba.valid_netbios_name(name))
class BlackboxProcessError(Exception):
    """Raised when a blackbox command exits with an unexpected status.

    The instance carries the exit code (``returncode``), the command line
    that was run (``cmd``), and the captured standard output (``stdout``)
    and standard error (``stderr``).
    """

    def __init__(self, returncode, cmd, stdout, stderr):
        self.returncode, self.cmd = returncode, cmd
        self.stdout, self.stderr = stdout, stderr

    def __str__(self):
        return ("Command '%(cmd)s'; exit status %(returncode)d; "
                "stdout: '%(stdout)s'; stderr: '%(stderr)s'" % vars(self))
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests (running command-line tools).

    Command lines are executed through the shell (shell=True), so only pass
    trusted input.
    """

    def _make_cmdline(self, line):
        # If the first word names a program in the bin/ directory four
        # levels above this file, run that copy via its absolute path.
        bindir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "../../../../bin"))
        program, sep, rest = line.partition(" ")
        candidate = os.path.join(bindir, program)
        if os.path.exists(candidate):
            program = candidate
        return program + sep + rest

    def check_run(self, line):
        """Run *line* and require exit status 0."""
        self.check_exit_code(line, 0)

    def check_exit_code(self, line, expected):
        """Run *line*; raise BlackboxProcessError unless it exits with *expected*."""
        cmdline = self._make_cmdline(line)
        proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, shell=True)
        out, err = proc.communicate()
        if proc.returncode != expected:
            raise BlackboxProcessError(proc.returncode, cmdline, out, err)

    def check_output(self, line):
        """Run *line* and return its stdout; raise BlackboxProcessError on non-zero exit."""
        cmdline = self._make_cmdline(line)
        proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, shell=True,
                                close_fds=True)
        out, err = proc.communicate()
        if proc.returncode:
            raise BlackboxProcessError(proc.returncode, cmdline, out, err)
        return out
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
                  flags=0, ldb_options=None, ldap_only=False, global_schema=True):
    """Create SamDB instance and connects to samdb_url database.

    :param samdb_url: Url for database to connect to. A value without a
        scheme becomes ``tdb://<value>`` when it names an existing file (and
        ldap_only is not set), otherwise ``ldap://<value>``.
    :param lp: Optional loadparm object; defaults to env_loadparm().
    :param session_info: Optional session information; defaults to
        samba.auth.system_session(lp).
    :param credentials: Optional credentials; defaults to the module-level
        cmdline_credentials.
    :param flags: Optional LDB flags
    :param ldb_options: Optional list of LDB options. For ``ldap://`` URLs
        ``modules:paged_searches`` is added to them.
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema.
    :raises AssertionError: if ldap_only is set and the URL is not ldap://.

    Added value for tests is that we have a shorthand function
    to make proper URL for ldb.connect() while using default
    parameters for connection based on test environment
    """
    if "://" not in samdb_url:
        if not ldap_only and os.path.isfile(samdb_url):
            samdb_url = "tdb://%s" % samdb_url
        else:
            samdb_url = "ldap://%s" % samdb_url
    # use 'paged_search' module when connecting remotely
    if samdb_url.startswith("ldap://"):
        # Extend (a copy of) the caller's options. They used to be replaced
        # wholesale, silently discarding whatever the caller passed in.
        ldb_options = list(ldb_options or [])
        if "modules:paged_searches" not in ldb_options:
            ldb_options.append("modules:paged_searches")
    elif ldap_only:
        raise AssertionError("Trying to connect to %s while remote "
                             "connection is required" % samdb_url)

    # set defaults for test environment
    if lp is None:
        lp = env_loadparm()
    if session_info is None:
        session_info = samba.auth.system_session(lp)
    if credentials is None:
        credentials = cmdline_credentials

    return SamDB(url=samdb_url,
                 lp=lp,
                 session_info=session_info,
                 credentials=credentials,
                 flags=flags,
                 options=ldb_options,
                 global_schema=global_schema)
def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
                     flags=0, ldb_options=None, ldap_only=False,
                     global_schema=True):
    """Connects to samdb_url database and reads its rootDSE record.

    :param samdb_url: Url for database to connect to.
    :param lp: Optional loadparm object
    :param session_info: Optional session information
    :param credentials: Optional credentials; connect_samdb() defaults them
        to the module-level cmdline_credentials.
    :param flags: Optional LDB flags
    :param ldb_options: Optional list of LDB options
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema (new, defaults to
        True as in connect_samdb()).
    :return: (sam_db_connection, rootDse_record) tuple
    """
    sam_db = connect_samdb(samdb_url, lp=lp, session_info=session_info,
                           credentials=credentials, flags=flags,
                           ldb_options=ldb_options, ldap_only=ldap_only,
                           global_schema=global_schema)
    # Base-scope search on the empty DN returns the single rootDSE entry.
    res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
                        attrs=["*"])
    return (sam_db, res[0])
def connect_samdb_env(env_url, env_username, env_password, lp=None):
    """Connect to SamDB by getting URL and Credentials from environment

    :param env_url: Environment variable name to get ldb url from
    :param env_username: Username environment variable
    :param env_password: Password environment variable
    :param lp: Optional loadparm object. A fresh param.LoadParm() is used
        when omitted; previously a caller-supplied value was ignored.
    :return: sam_db_connection
    :raises AssertionError: if one of the environment variables is unset
        (via env_get_var_value()).
    """
    samdb_url = env_get_var_value(env_url)
    creds = credentials.Credentials()

    if lp is None:
        lp = param.LoadParm()
    # guess Credentials parameters here. Otherwise workstation
    # and domain fields are NULL and gencache code segfaults
    creds.guess(lp)
    creds.set_username(env_get_var_value(env_username))
    creds.set_password(env_get_var_value(env_password))
    return connect_samdb(samdb_url, credentials=creds, lp=lp)
def delete_force(samdb, dn, **kwargs):
    """Delete *dn*, treating "no such object" as success.

    :param samdb: connection object providing ``delete(dn, **kwargs)``
    :param dn: DN of the entry to delete
    :param kwargs: passed through unchanged to ``samdb.delete()``
    :raises AssertionError: for any ldb.LdbError other than
        ldb.ERR_NO_SUCH_OBJECT.
    """
    try:
        samdb.delete(dn, **kwargs)
    except ldb.LdbError as error:
        (num, errstr) = error.args
        if num != ldb.ERR_NO_SUCH_OBJECT:
            # Explicit raise (not ``assert``) so real failures are not
            # silently swallowed under ``python -O``.
            raise AssertionError("ldb.delete() failed: %s" % errstr)