1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
34 import samba.dcerpc.base
35 from samba.compat import PY3
38 from samba.samdb import SamDB
40 import samba.dcerpc.dcerpc
41 import samba.dcerpc.epmapper
42 from samba import gensec
45 from unittest import SkipTest
# Fallback SkipTest exception type.
# NOTE(review): as listed, this class is defined unconditionally right after
# `from unittest import SkipTest`, so it would shadow unittest's SkipTest.
# The surrounding lines (presumably a try/except ImportError guard and the
# class body) are missing from this listing - confirm against the full file.
47 class SkipTest(Exception):
# Translation table used by hexdump(): a 256-character string indexed by
# character code. Characters whose repr() is a plain quoted single character
# (no escape needed) map to themselves; every other character maps to '.'.
HEXDUMP_FILTER = ''.join([chr(x) if len(repr(chr(x))) == 3 else '.' for x in range(256)])
class TestCase(unittest.TestCase):
    """A Samba test case."""

    def setUp(self):
        """Optionally override the samba debug level for this test.

        When the TEST_DEBUG_LEVEL environment variable is set, its integer
        value becomes the samba debug level while the test runs, and the
        level that was active beforehand is restored by a cleanup.
        """
        super(TestCase, self).setUp()
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Restore the level saved above once the test finishes; passing
            # test_debug_level here would leave the override in place.
            self.addCleanup(samba.set_debug_level, self._old_debug_level)
# Presumably returns the LoadParm for the test environment.
# NOTE(review): the body of get_loadparm is not present in this listing.
64 def get_loadparm(self):
# Return the module-level cmdline_credentials (None unless assigned elsewhere).
67 def get_credentials(self):
68 return cmdline_credentials
# Format src as a hex dump. Each output line shows an offset, two hex columns
# (hl, hr) and the printable text of both halves, with characters whose
# repr() needs escaping replaced by '.' via HEXDUMP_FILTER.
70 def hexdump(self, src):
# NOTE(review): the loop header, the slicing of ll/lr and the return are not
# present in this listing; ll and lr are presumably the two halves of each
# row (each hex column is padded to 8*3 characters, i.e. 8 bytes).
# NOTE(review): ord(x) assumes iterating ll/lr yields 1-character strings
# (Python 2 str); iterating Python 3 bytes yields ints, for which ord()
# raises TypeError - confirm which type src is.
77 hl = ' '.join(["%02X" % ord(x) for x in ll])
78 hr = ' '.join(["%02X" % ord(x) for x in lr])
# Replace characters that repr() would escape with '.' for the text columns.
79 ll = ll.translate(HEXDUMP_FILTER)
80 lr = lr.translate(HEXDUMP_FILTER)
# N is the row offset printed as [%04X]; presumably advanced per row (not shown).
81 result += "[%04X] %-*s %-*s %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
# Compatibility shims for unittest features missing before Python 2.7; they
# are only defined when running on such an interpreter.
85 # These functions didn't exist before Python2.7:
86 if sys.version_info < (2, 7):
# Skip the current test by raising SkipTest; run() turns it into a skip report.
89 def skipTest(self, reason):
90 raise SkipTest(reason)
def assertIn(self, member, container, msg=None):
    """Fail unless *member* is in *container*."""
    self.assertTrue(member in container, msg)

def assertIs(self, a, b, msg=None):
    """Fail unless *a* and *b* are the same object."""
    self.assertTrue(a is b, msg)

def assertIsNot(self, a, b, msg=None):
    """Fail if *a* and *b* are the same object."""
    self.assertTrue(a is not b, msg)

def assertIsNotNone(self, a, msg=None):
    """Fail if *a* is None."""
    # Forward msg like every other shim so custom failure messages are shown.
    self.assertTrue(a is not None, msg)

def assertIsInstance(self, a, b, msg=None):
    """Fail unless *a* is an instance of *b*."""
    self.assertTrue(isinstance(a, b), msg)

def assertIsNone(self, a, msg=None):
    """Fail unless *a* is None."""
    self.assertTrue(a is None, msg)

def assertGreater(self, a, b, msg=None):
    """Fail unless a > b."""
    self.assertTrue(a > b, msg)

def assertGreaterEqual(self, a, b, msg=None):
    """Fail unless a >= b."""
    self.assertTrue(a >= b, msg)

def assertLess(self, a, b, msg=None):
    """Fail unless a < b."""
    self.assertTrue(a < b, msg)

def assertLessEqual(self, a, b, msg=None):
    """Fail unless a <= b."""
    self.assertTrue(a <= b, msg)
# addCleanup shim: append (fn, args, kwargs) to self._cleanups; run() calls
# these in reverse order after the test.
122 def addCleanup(self, fn, *args, **kwargs):
123 self._cleanups = getattr(self, "_cleanups", []) + [
# _addSkip: report a skip through result.addSkip when the result object has
# one; otherwise warn that skips are not reported and record a success.
126 def _addSkip(self, result, reason):
127 addSkip = getattr(result, 'addSkip', None)
128 if addSkip is not None:
129 addSkip(self, reason)
131 warnings.warn("TestResult has no addSkip method, skips not reported",
133 result.addSuccess(self)
# run shim: drive a single test, translating SkipTest into a skip,
# failureException into a failure and (presumably) other exceptions into
# errors via result.addError.
# NOTE(review): many lines of this method are not present in this listing:
# the try headers, the calls they guard (presumably setUp(), the test method
# and tearDown()), the KeyboardInterrupt handler bodies, the `ok` bookkeeping
# and the cleanup call. These comments describe only the visible lines.
135 def run(self, result=None):
136 if result is None: result = self.defaultTestResult()
137 result.startTest(self)
138 testMethod = getattr(self, self._testMethodName)
# First guarded phase (presumably setUp()).
142 except SkipTest as e:
143 self._addSkip(result, str(e))
145 except KeyboardInterrupt:
148 result.addError(self, self._exc_info())
# Second guarded phase (presumably the test method itself); assertion
# failures are reported as failures rather than errors.
155 except SkipTest as e:
156 self._addSkip(result, str(e))
158 except self.failureException:
159 result.addFailure(self, self._exc_info())
160 except KeyboardInterrupt:
163 result.addError(self, self._exc_info())
# Third guarded phase (presumably tearDown()).
167 except SkipTest as e:
168 self._addSkip(result, str(e))
169 except KeyboardInterrupt:
172 result.addError(self, self._exc_info())
# Run registered cleanups, most recently added first.
175 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
# Record success only when no phase cleared the `ok` flag (set on unseen lines).
177 if ok: result.addSuccess(self)
179 result.stopTest(self)
# Test case that opens a samba.Ldb on a temporary file name.
182 class LdbTestCase(TestCase):
183 """Trivial test case for running tests against a LDB."""
# setUp (its `def` line is not present in this listing): choose a temporary
# file name and open an Ldb on it.
186 super(LdbTestCase, self).setUp()
# NOTE(review): os.tempnam() exists only in Python 2 (it was removed in
# Python 3) and is race-prone; tempfile.mkstemp()/mkdtemp() is the usual
# replacement.
187 self.filename = os.tempnam()
188 self.ldb = samba.Ldb(self.filename)
# The mutable default `modules=[]` is only read (joined), never mutated.
190 def set_modules(self, modules=[]):
191 """Change the modules for this Ldb."""
# Build an @MODULES record whose @LIST is the comma-joined module names, then
# reopen the database. NOTE(review): the lines creating `m` and storing it
# are not present in this listing - presumably ldb.Message() and an add().
193 m.dn = ldb.Dn(self.ldb, "@MODULES")
194 m["@LIST"] = ",".join(modules)
196 self.ldb = samba.Ldb(self.filename)
class TestCaseInTempDir(TestCase):
    """Test case that provides a fresh temporary directory in self.tempdir.

    The directory must be empty again when the test finishes: the cleanup
    asserts this and then removes it.
    """

    def setUp(self):
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        """Assert self.tempdir is empty, then delete it."""
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual([], os.listdir(self.tempdir))
        os.rmdir(self.tempdir)
# Build a LoadParm from the smb.conf named by $SMB_CONF_PATH, raising a
# KeyError with a clearer message when the variable is unset.
# NOTE(review): the enclosing `def` line and the try/except KeyError lines are
# not present in this listing, nor is any return (presumably `return lp`).
213 lp = param.LoadParm()
215 lp.load(os.environ["SMB_CONF_PATH"])
217 raise KeyError("SMB_CONF_PATH not set")
# Return os.environ[var_name], asserting that the variable is set.
# NOTE(review): the body under `if var_name not in os.environ.keys()` is not
# present in this listing; presumably it returns early (e.g. None) when
# allow_missing is true - confirm. The check is an `assert`, so it is skipped
# under `python -O`; a missing variable then raises KeyError from the final
# lookup instead (unless the allow_missing branch returns first).
221 def env_get_var_value(var_name, allow_missing=False):
222 """Returns value for variable in os.environ
224 Function throws AssertionError if variable is not defined.
225 Unit-test based python tests require certain input params
226 to be set in environment, otherwise they can't be run
229 if var_name not in os.environ.keys():
231 assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
232 return os.environ[var_name]
# Module-level credentials returned by TestCase.get_credentials() and used by
# connect_samdb() when no credentials are passed. Initialised to None here;
# presumably assigned by the test runner before tests use it (TODO confirm).
235 cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """Common base class for DCE/RPC interface tests."""
class ValidNetbiosNameTests(TestCase):
    """Exercise samba.valid_netbios_name() on accepted and rejected names."""

    def test_valid(self):
        # A short upper-case name is accepted.
        good_name = "FOO"
        self.assertTrue(samba.valid_netbios_name(good_name))

    def test_too_long(self):
        # "FOO" repeated ten times (30 characters) is rejected as too long.
        overlong_name = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(overlong_name))

    def test_invalid_characters(self):
        # A leading '*' makes the name invalid.
        bad_name = "*BLA"
        self.assertFalse(samba.valid_netbios_name(bad_name))
# Exception raised by BlackboxTestCase.check_run() and check_output() when the
# executed command exits with a non-zero status.
253 class BlackboxProcessError(Exception):
254 """This is raised when check_run() or check_output() process returns a non-zero exit status
256 Exception instance should contain the exact exit code (S.returncode),
257 command line (S.cmd), process output (S.stdout) and process error stream
261 def __init__(self, returncode, cmd, stdout, stderr):
262 self.returncode = returncode
# NOTE(review): the remaining attribute assignments and the `def __str__`
# line are not present in this listing; the message below reads self.cmd,
# self.returncode, self.stdout and self.stderr.
268 return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
269 self.stdout, self.stderr)
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests."""

    def _make_cmdline(self, line):
        """Resolve the command name of *line* against the build's bin directory.

        If the first space-separated word of *line* names an existing entry in
        the ``bin`` directory four levels above this module, that word is
        replaced by its absolute path; otherwise *line* is returned unchanged.
        """
        bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
        parts = line.split(" ")
        if os.path.exists(os.path.join(bindir, parts[0])):
            parts[0] = os.path.join(bindir, parts[0])
            line = " ".join(parts)
        return line

    def _run_checked(self, line, **popen_kwargs):
        """Run *line* through the shell and return its stdout.

        :param line: Shell command line (its command is resolved via
            _make_cmdline).
        :param popen_kwargs: Extra keyword arguments for subprocess.Popen.
        :raises BlackboxProcessError: if the command exits non-zero.
        """
        line = self._make_cmdline(line)
        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             shell=True, **popen_kwargs)
        # communicate() drains both pipes while waiting for the child;
        # wait() followed by read() can deadlock once a pipe buffer fills.
        stdout, stderr = p.communicate()
        if p.returncode:
            raise BlackboxProcessError(p.returncode, line, stdout, stderr)
        return stdout

    def check_run(self, line):
        """Run *line*; raise BlackboxProcessError on a non-zero exit status."""
        self._run_checked(line)

    def check_output(self, line):
        """Run *line* and return its stdout; raise BlackboxProcessError on failure."""
        return self._run_checked(line, close_fds=True)
# NOTE(review): several lines of this function are not present in this
# listing (the docstring's closing quotes, an `else:`/`elif ldap_only:` pair,
# the lp default, and some SamDB keyword arguments); comments below describe
# only the visible lines.
298 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
299 flags=0, ldb_options=None, ldap_only=False, global_schema=True):
300 """Create a SamDB instance connected to the samdb_url database.
302 :param samdb_url: Url for database to connect to.
303 :param lp: Optional loadparm object
304 :param session_info: Optional session information
305 :param credentials: Optional credentials, defaults to cmdline_credentials.
306 :param flags: Optional LDB flags
:param ldb_options: Optional ldb options; replaced by
    ["modules:paged_searches"] for ldap:// URLs.
307 :param ldap_only: If set, only remote LDAP connection will be created.
308 :param global_schema: Whether to use global schema.
310 Added value for tests is that we have a shorthand function
311 to make proper URL for ldb.connect() while using default
312 parameters for connection based on test environment
# A bare host or path gets a scheme: tdb:// for an existing local file
# (unless ldap_only), presumably ldap:// otherwise (the `else:` is not shown).
314 if not "://" in samdb_url:
315 if not ldap_only and os.path.isfile(samdb_url):
316 samdb_url = "tdb://%s" % samdb_url
318 samdb_url = "ldap://%s" % samdb_url
# NOTE(review): this discards any caller-supplied ldb_options for ldap:// URLs.
319 # use 'paged_search' module when connecting remotely
320 if samdb_url.startswith("ldap://"):
321 ldb_options = ["modules:paged_searches"]
# Presumably reached via `elif ldap_only:` (not shown): non-LDAP URLs are
# rejected when a remote connection was required.
323 raise AssertionError("Trying to connect to %s while remote "
324 "connection is required" % samdb_url)
326 # set defaults for test environment
329 if session_info is None:
330 session_info = samba.auth.system_session(lp)
# Unlike the old docstring's "anonymous", the default is cmdline_credentials.
331 if credentials is None:
332 credentials = cmdline_credentials
334 return SamDB(url=samdb_url,
336 session_info=session_info,
337 credentials=credentials,
340 global_schema=global_schema)
# Connect via connect_samdb() and also return the rootDSE record.
# NOTE(review): the docstring's closing quotes and the remaining search
# arguments are on lines not present in this listing.
343 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
344 flags=0, ldb_options=None, ldap_only=False):
345 """Connects to samdb_url database
347 :param samdb_url: Url for database to connect to.
348 :param lp: Optional loadparm object
349 :param session_info: Optional session information
350 :param credentials: Optional credentials, defaults to cmdline_credentials (via connect_samdb).
351 :param flags: Optional LDB flags
352 :param ldap_only: If set, only remote LDAP connection will be created.
353 :return: (sam_db_connection, rootDse_record) tuple
# Positional arguments map onto connect_samdb's parameters in order;
# global_schema is left at its default.
355 sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
356 flags, ldb_options, ldap_only)
# A base-scope search on the empty DN reads the rootDSE.
358 res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
360 return (sam_db, res[0])
# Build credentials from environment variables and connect via connect_samdb().
# NOTE(review): as listed, the lp argument is unconditionally replaced by a
# fresh LoadParm below. A line appears to be missing between `creds = ...`
# and the comment (presumably an `if lp is None:` guard), as does the line
# after LoadParm() (presumably a credentials guess using lp) - confirm.
363 def connect_samdb_env(env_url, env_username, env_password, lp=None):
364 """Connect to SamDB by getting URL and Credentials from environment
366 :param env_url: Environment variable name to get ldb url from
367 :param env_username: Username environment variable
368 :param env_password: Password environment variable
:param lp: Optional loadparm object
369 :return: sam_db_connection
371 samdb_url = env_get_var_value(env_url)
372 creds = credentials.Credentials()
374 # guess Credentials parameters here. Otherwise workstation
375 # and domain fields are NULL and gencache code segfaults
376 lp = param.LoadParm()
378 creds.set_username(env_get_var_value(env_username))
379 creds.set_password(env_get_var_value(env_password))
380 return connect_samdb(samdb_url, credentials=creds, lp=lp)
# Delete dn from samdb, tolerating ERR_NO_SUCH_OBJECT; any other LdbError
# fails the assertion below.
# NOTE(review): the `try:` and the delete call (presumably samdb.delete(dn))
# are not present in this listing. Because the check is an `assert`, it is
# skipped under `python -O` and other LdbErrors would then be silently
# swallowed - consider an explicit `raise`.
383 def delete_force(samdb, dn):
386 except ldb.LdbError as error:
387 (num, errstr) = error.args
388 assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr