1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
34 import samba.dcerpc.base
35 from samba.compat import PY3
38 from samba.samdb import SamDB
40 import samba.dcerpc.dcerpc
41 import samba.dcerpc.epmapper
44 from unittest import SkipTest
# Exception used to signal that a test should be skipped. This definition
# rebinds the SkipTest name imported from unittest on the line above.
# NOTE(review): presumably this sits in an `except ImportError:` fallback for
# Pythons without unittest.SkipTest; that structure is not visible here.
46 class SkipTest(Exception):
# Translation table used by hexdump(): entry i is chr(i) when repr(chr(i))
# is exactly three characters (quote, char, quote), i.e. the character is
# shown literally, and '.' otherwise (control characters, backslash and
# other characters that repr() escapes).
HEXDUMP_FILTER = ''.join([chr(x) if len(repr(chr(x))) == 3 else '.'
                          for x in range(256)])
51 class TestCase(unittest.TestCase):
52 """A Samba test case: unittest.TestCase plus debug-level, loadparm, credentials and hexdump helpers."""
super(TestCase, self).setUp()
# When TEST_DEBUG_LEVEL is set, it overrides the Samba debug level for the
# duration of this test only.
test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
if test_debug_level is not None:
    test_debug_level = int(test_debug_level)
    self._old_debug_level = samba.get_debug_level()
    samba.set_debug_level(test_debug_level)
    # Restore the level that was active before the test, so the override
    # does not persist into subsequent tests.
    self.addCleanup(samba.set_debug_level, self._old_debug_level)
# Return the loadparm object for this test.
# NOTE(review): the body is on lines not shown; presumably it delegates to
# the module's environment-based loadparm helper - confirm.
63 def get_loadparm(self):
def get_credentials(self):
    """Return the credentials tests should authenticate with.

    This is simply the module-level ``cmdline_credentials`` object.
    """
    creds = cmdline_credentials
    return creds
# Format `src` as a hex dump, one text line per row: an offset, two hex
# groups and a printable rendering of the same bytes.
# NOTE(review): the loop that slices `src` into `ll`/`lr` (presumably up to
# 8 items each, per the 8*3 column widths) and advances `N`, plus the final
# return, are on lines not shown here - confirm.
69 def hexdump(self, src):
# Hex columns for the two halves of the current row.
# NOTE(review): ord(x) assumes ll/lr iterate as 1-char strings (Python 2 str);
# Python 3 bytes would yield ints and ord() would raise - confirm callers.
76 hl = ' '.join(["%02X" % ord(x) for x in ll])
77 hr = ' '.join(["%02X" % ord(x) for x in lr])
# Replace characters that repr() would escape with '.' via HEXDUMP_FILTER.
78 ll = ll.translate(HEXDUMP_FILTER)
79 lr = lr.translate(HEXDUMP_FILTER)
# Each hex group is left-justified to 8*3 columns so short rows stay aligned.
80 result += "[%04X] %-*s %-*s %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
84 # Fallback implementations of TestCase methods that don't exist before Python 2.7.
# On 2.7 and later the inherited unittest.TestCase versions are used instead.
85 if sys.version_info < (2, 7):
def skipTest(self, reason):
    """Abort the running test by raising SkipTest carrying *reason*."""
    skip_exc = SkipTest(reason)
    raise skip_exc
def assertIn(self, member, container, msg=None):
    """Fail unless *member* is contained in *container*."""
    found = member in container
    self.assertTrue(found, msg)
def assertIs(self, a, b, msg=None):
    """Fail unless *a* and *b* are the very same object."""
    same_object = a is b
    self.assertTrue(same_object, msg)
def assertIsNot(self, a, b, msg=None):
    """Fail if *a* and *b* are the very same object."""
    distinct = not (a is b)
    self.assertTrue(distinct, msg)
def assertIsNotNone(self, a, msg=None):
    """Fail if *a* is None.

    *msg* is forwarded to assertTrue(), as every other assertion helper here
    does, so a caller-supplied failure message is actually reported.
    """
    self.assertTrue(a is not None, msg)
def assertIsInstance(self, a, b, msg=None):
    """Fail unless isinstance(a, b) holds."""
    matches = isinstance(a, b)
    self.assertTrue(matches, msg)
def assertIsNone(self, a, msg=None):
    """Fail unless *a* is None."""
    is_none = None is a
    self.assertTrue(is_none, msg)
def assertGreater(self, a, b, msg=None):
    """Fail unless a > b."""
    is_greater = a > b
    self.assertTrue(is_greater, msg)
def assertGreaterEqual(self, a, b, msg=None):
    """Fail unless a >= b."""
    at_least = a >= b
    self.assertTrue(at_least, msg)
def assertLess(self, a, b, msg=None):
    """Fail unless a < b."""
    is_less = a < b
    self.assertTrue(is_less, msg)
def assertLessEqual(self, a, b, msg=None):
    """Fail unless a <= b."""
    at_most = a <= b
    self.assertTrue(at_most, msg)
# Pre-2.7 fallback: register fn(*args, **kwargs) to be called by run() after
# the test (run() walks self._cleanups in reverse). A new list is built on
# each call rather than appending in place.
121 def addCleanup(self, fn, *args, **kwargs):
122 self._cleanups = getattr(self, "_cleanups", []) + [
# Report a skipped test to `result`, preferring its addSkip() method.
125 def _addSkip(self, result, reason):
126 addSkip = getattr(result, 'addSkip', None)
127 if addSkip is not None:
128 addSkip(self, reason)
# Presumably the `else:` branch (line not shown) for results without
# addSkip(): warn that skips are not reported, then record the test via
# addSuccess().
130 warnings.warn("TestResult has no addSkip method, skips not reported",
132 result.addSuccess(self)
# Pre-2.7 fallback run(): looks up the test method named by
# self._testMethodName, reports the outcome to `result`, calls
# result.startTest() first and result.stopTest() at the end.
# NOTE(review): the try blocks these handlers belong to, and the calls they
# guard, are on lines not shown; comments below describe visible statements.
134 def run(self, result=None):
135 if result is None: result = self.defaultTestResult()
136 result.startTest(self)
137 testMethod = getattr(self, self._testMethodName)
# First group of handlers (presumably around setUp): a SkipTest is reported
# as a skip rather than an error.
141 except SkipTest as e:
142 self._addSkip(result, str(e))
# KeyboardInterrupt gets its own clause (body not shown; presumably
# re-raises) so it is not reported as an ordinary test error.
144 except KeyboardInterrupt:
147 result.addError(self, self._exc_info())
# Second group (presumably around the test method): skip, assertion
# failure (failureException -> addFailure), or other error.
154 except SkipTest as e:
155 self._addSkip(result, str(e))
157 except self.failureException:
158 result.addFailure(self, self._exc_info())
159 except KeyboardInterrupt:
162 result.addError(self, self._exc_info())
# Third group (presumably around tearDown).
166 except SkipTest as e:
167 self._addSkip(result, str(e))
168 except KeyboardInterrupt:
171 result.addError(self, self._exc_info())
# Call registered cleanups in reverse registration order (LIFO).
174 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
# `ok` is set on lines not shown; presumably true only when the test passed.
176 if ok: result.addSuccess(self)
178 result.stopTest(self)
# Test case whose setUp opens a scratch LDB file (self.filename, self.ldb).
181 class LdbTestCase(TestCase):
182 """Trivial test case for running tests against an LDB."""
super(LdbTestCase, self).setUp()
# Create the backing file atomically under a unique name. os.tempnam()
# only returns a name (leaving a window for another process to create that
# path first) and does not exist at all on Python 3.
tmp = tempfile.NamedTemporaryFile(delete=False)
tmp.close()
self.filename = tmp.name
self.ldb = samba.Ldb(self.filename)
# NOTE(review): mutable default argument; harmless here because `modules`
# is only read (joined), never mutated.
189 def set_modules(self, modules=[]):
190 """Set the @MODULES @LIST of this Ldb to `modules` and reopen it."""
# Populate the @MODULES record with a comma-separated module list (creation
# of `m` and the call that stores it are on lines not shown).
192 m.dn = ldb.Dn(self.ldb, "@MODULES")
193 m["@LIST"] = ",".join(modules)
# Reopen the database, presumably so the new module list takes effect.
195 self.ldb = samba.Ldb(self.filename)
# TestCase that gives each test a fresh, empty temporary directory in
# self.tempdir.
198 class TestCaseInTempDir(TestCase):
# setUp body (its def line is not shown): create the directory and register
# _remove_tempdir as a cleanup, which asserts the directory is empty and
# removes it after the test.
201 super(TestCaseInTempDir, self).setUp()
202 self.tempdir = tempfile.mkdtemp()
203 self.addCleanup(self._remove_tempdir)
205 def _remove_tempdir(self):
206 self.assertEquals([], os.listdir(self.tempdir))
207 os.rmdir(self.tempdir)
# Body of a loadparm helper whose def line is not shown (presumably
# env_loadparm()): loads the smb.conf named by $SMB_CONF_PATH. The
# try/except and return lines are also not shown; presumably a missing
# variable is re-raised as the clearer KeyError below.
212 lp = param.LoadParm()
214 lp.load(os.environ["SMB_CONF_PATH"])
216 raise KeyError("SMB_CONF_PATH not set")
# Fetch a test input parameter from the process environment.
220 def env_get_var_value(var_name, allow_missing=False):
221 """Return the value of environment variable var_name from os.environ.
223 Raises AssertionError if the variable is not defined.
224 Unit-test based python tests require certain input params
225 to be set in environment, otherwise they can't be run.
228 if var_name not in os.environ.keys():
# NOTE(review): the allow_missing handling around the check above is on
# lines not shown; presumably a missing variable yields None when
# allow_missing is set - confirm.
# NOTE(review): `assert` is skipped under `python -O`, in which case a
# missing variable surfaces as KeyError from os.environ[...] instead.
230 assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
231 return os.environ[var_name]
# Module-level credentials returned by TestCase.get_credentials() and used by
# connect_samdb() when no credentials are passed.
# NOTE(review): starts as None; presumably assigned by the test runner - confirm.
234 cmdline_credentials = None
# Common base type for DCE/RPC tests; no members beyond its docstring are
# visible here.
236 class RpcInterfaceTestCase(TestCase):
237 """DCE/RPC test case."""
class ValidNetbiosNameTests(TestCase):
    """Exercise samba.valid_netbios_name() with good and bad names."""

    def test_valid(self):
        """A short upper-case name is accepted."""
        candidate = "FOO"
        self.assertTrue(samba.valid_netbios_name(candidate))

    def test_too_long(self):
        """A 30-character name is rejected as too long."""
        candidate = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(candidate))

    def test_invalid_characters(self):
        """A name containing '*' is rejected."""
        candidate = "*BLA"
        self.assertFalse(samba.valid_netbios_name(candidate))
252 class BlackboxProcessError(Exception):
253 """Raised when a check_run() or check_output() subprocess exits with a non-zero status
255 Exception instance should contain the exact exit code (S.returncode),
256 command line (S.cmd), process output (S.stdout) and process error stream
260 def __init__(self, returncode, cmd, stdout, stderr):
261 self.returncode = returncode
# Human-readable summary (presumably __str__, whose def line is not shown):
# command, exit status, then captured stdout and stderr.
267 return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
268 self.stdout, self.stderr)
# Base class for tests that run command-line tools through the shell and
# check their exit status and output.
270 class BlackboxTestCase(TestCaseInTempDir):
271 """Base test case for blackbox tests."""
# If the first word of `line` names a file in the bin/ directory four levels
# above this file, replace it with that absolute path; the remaining words
# are re-joined unchanged. Splitting on single spaces does not understand
# shell quoting. (The return statement is on a line not shown.)
273 def _make_cmdline(self, line):
274 bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
275 parts = line.split(" ")
276 if os.path.exists(os.path.join(bindir, parts[0])):
277 parts[0] = os.path.join(bindir, parts[0])
278 line = " ".join(parts)
# Run `line` through the shell; on a non-zero exit (the wait/check lines are
# not shown) raise BlackboxProcessError with the captured stdout and stderr.
# NOTE(review): shell=True executes `line` via the shell; pass only trusted
# command lines.
# NOTE(review): unlike check_output(), this Popen does not pass close_fds=True.
281 def check_run(self, line):
282 line = self._make_cmdline(line)
283 p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
# NOTE(review): if the process is waited on before these read()s, a child
# that fills a pipe buffer can deadlock; Popen.communicate() avoids that -
# confirm against the lines not shown.
286 raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
# Like check_run(), but on success return the command's captured stdout.
288 def check_output(self, line):
289 line = self._make_cmdline(line)
290 p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
293 raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
294 return p.stdout.read()
# Convenience SamDB constructor for tests: turns a bare path or host into a
# tdb:// or ldap:// URL and fills in test-environment defaults.
297 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
298 flags=0, ldb_options=None, ldap_only=False, global_schema=True):
299 """Create a SamDB instance connected to the samdb_url database.
301 :param samdb_url: Url for database to connect to.
302 :param lp: Optional loadparm object
303 :param session_info: Optional session information; defaults to samba.auth.system_session(lp).
304 :param credentials: Optional credentials; defaults to the module-level cmdline_credentials.
305 :param flags: Optional LDB flags
:param ldb_options: Optional list of LDB options; replaced by ["modules:paged_searches"] for ldap:// URLs.
306 :param ldap_only: If set, only remote LDAP connection will be created.
307 :param global_schema: Whether to use global schema.
309 For tests, this is a shorthand that builds a proper URL for
310 ldb.connect() and fills in default connection parameters
311 from the test environment.
# A URL without a scheme becomes tdb:// when it names an existing local file
# (unless ldap_only is set), and presumably ldap:// otherwise (the `else:`
# line is not shown).
313 if not "://" in samdb_url:
314 if not ldap_only and os.path.isfile(samdb_url):
315 samdb_url = "tdb://%s" % samdb_url
317 samdb_url = "ldap://%s" % samdb_url
318 # use the 'paged_searches' module when connecting remotely
# NOTE(review): this overwrites, rather than extends, any ldb_options passed
# by the caller.
319 if samdb_url.startswith("ldap://"):
320 ldb_options = ["modules:paged_searches"]
# Presumably reached when ldap_only is set but the URL is not ldap:// (the
# guarding condition is on a line not shown).
322 raise AssertionError("Trying to connect to %s while remote "
323 "connection is required" % samdb_url)
325 # set defaults for test environment
328 if session_info is None:
329 session_info = samba.auth.system_session(lp)
330 if credentials is None:
331 credentials = cmdline_credentials
333 return SamDB(url=samdb_url,
335 session_info=session_info,
336 credentials=credentials,
339 global_schema=global_schema)
342 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
343 flags=0, ldb_options=None, ldap_only=False):
344 """Connect to the samdb_url database and also fetch its rootDSE record.
346 :param samdb_url: Url for database to connect to.
347 :param lp: Optional loadparm object
348 :param session_info: Optional session information
349 :param credentials: Optional credentials; connect_samdb() defaults them to cmdline_credentials.
350 :param flags: Optional LDB flags
:param ldb_options: Optional LDB options, passed through to connect_samdb().
351 :param ldap_only: If set, only remote LDAP connection will be created.
352 :return: (sam_db_connection, rootDse_record) tuple
354 sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
355 flags, ldb_options, ldap_only)
# global_schema is not forwarded, so connect_samdb()'s default applies.
# A base-scope search on the empty DN reads the rootDSE; res[0] is that record.
357 res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
359 return (sam_db, res[0])
362 def connect_samdb_env(env_url, env_username, env_password, lp=None):
363 """Connect to SamDB by getting URL and Credentials from environment
365 :param env_url: Name of the environment variable holding the ldb URL
366 :param env_username: Username environment variable
367 :param env_password: Password environment variable
:param lp: Optional loadparm object
368 :return: sam_db_connection
370 samdb_url = env_get_var_value(env_url)
371 creds = credentials.Credentials()
373 # Guess Credentials parameters here. Otherwise the workstation
374 # and domain fields are NULL and the gencache code segfaults.
# NOTE(review): the condition guarding this LoadParm() creation and the
# actual guessing call are on lines not shown; presumably this only runs
# when lp is None - confirm.
375 lp = param.LoadParm()
377 creds.set_username(env_get_var_value(env_username))
378 creds.set_password(env_get_var_value(env_password))
379 return connect_samdb(samdb_url, credentials=creds, lp=lp)
# Delete `dn` from `samdb`, tolerating a DN that does not exist.
# NOTE(review): the try block (presumably samdb.delete(dn)) is on lines not
# shown.
382 def delete_force(samdb, dn):
# Any LdbError other than ERR_NO_SUCH_OBJECT fails the assertion below.
# NOTE(review): `assert` is skipped under `python -O`, which would silently
# ignore such errors.
385 except ldb.LdbError as error:
386 (num, errstr) = error.args
387 assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr