python: samba.gensec: Port module to Python 3 compatible form
[samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import ldb
23 import samba
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 import socket
28 import struct
29 import subprocess
30 import sys
31 import tempfile
32 import unittest
33 import samba.auth
34 import samba.dcerpc.base
35 from samba.compat import PY3
36 if not PY3:
37     # Py2 only
38     from samba.samdb import SamDB
39     import samba.ndr
40     import samba.dcerpc.dcerpc
41     import samba.dcerpc.epmapper
42
43 try:
44     from unittest import SkipTest
45 except ImportError:
46     class SkipTest(Exception):
47         """Test skipped."""
48
49 HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
50
51 class TestCase(unittest.TestCase):
52     """A Samba test case."""
53
54     def setUp(self):
55         super(TestCase, self).setUp()
56         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
57         if test_debug_level is not None:
58             test_debug_level = int(test_debug_level)
59             self._old_debug_level = samba.get_debug_level()
60             samba.set_debug_level(test_debug_level)
61             self.addCleanup(samba.set_debug_level, test_debug_level)
62
63     def get_loadparm(self):
64         return env_loadparm()
65
66     def get_credentials(self):
67         return cmdline_credentials
68
69     def hexdump(self, src):
70         N = 0
71         result = ''
72         while src:
73             ll = src[:8]
74             lr = src[8:16]
75             src = src[16:]
76             hl = ' '.join(["%02X" % ord(x) for x in ll])
77             hr = ' '.join(["%02X" % ord(x) for x in lr])
78             ll = ll.translate(HEXDUMP_FILTER)
79             lr = lr.translate(HEXDUMP_FILTER)
80             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
81             N += 16
82         return result
83
84     # These functions didn't exist before Python2.7:
85     if sys.version_info < (2, 7):
86         import warnings
87
88         def skipTest(self, reason):
89             raise SkipTest(reason)
90
91         def assertIn(self, member, container, msg=None):
92             self.assertTrue(member in container, msg)
93
94         def assertIs(self, a, b, msg=None):
95             self.assertTrue(a is b, msg)
96
97         def assertIsNot(self, a, b, msg=None):
98             self.assertTrue(a is not b, msg)
99
100         def assertIsNotNone(self, a, msg=None):
101             self.assertTrue(a is not None)
102
103         def assertIsInstance(self, a, b, msg=None):
104             self.assertTrue(isinstance(a, b), msg)
105
106         def assertIsNone(self, a, msg=None):
107             self.assertTrue(a is None, msg)
108
109         def assertGreater(self, a, b, msg=None):
110             self.assertTrue(a > b, msg)
111
112         def assertGreaterEqual(self, a, b, msg=None):
113             self.assertTrue(a >= b, msg)
114
115         def assertLess(self, a, b, msg=None):
116             self.assertTrue(a < b, msg)
117
118         def assertLessEqual(self, a, b, msg=None):
119             self.assertTrue(a <= b, msg)
120
121         def addCleanup(self, fn, *args, **kwargs):
122             self._cleanups = getattr(self, "_cleanups", []) + [
123                 (fn, args, kwargs)]
124
125         def _addSkip(self, result, reason):
126             addSkip = getattr(result, 'addSkip', None)
127             if addSkip is not None:
128                 addSkip(self, reason)
129             else:
130                 warnings.warn("TestResult has no addSkip method, skips not reported",
131                               RuntimeWarning, 2)
132                 result.addSuccess(self)
133
134         def run(self, result=None):
135             if result is None: result = self.defaultTestResult()
136             result.startTest(self)
137             testMethod = getattr(self, self._testMethodName)
138             try:
139                 try:
140                     self.setUp()
141                 except SkipTest as e:
142                     self._addSkip(result, str(e))
143                     return
144                 except KeyboardInterrupt:
145                     raise
146                 except:
147                     result.addError(self, self._exc_info())
148                     return
149
150                 ok = False
151                 try:
152                     testMethod()
153                     ok = True
154                 except SkipTest as e:
155                     self._addSkip(result, str(e))
156                     return
157                 except self.failureException:
158                     result.addFailure(self, self._exc_info())
159                 except KeyboardInterrupt:
160                     raise
161                 except:
162                     result.addError(self, self._exc_info())
163
164                 try:
165                     self.tearDown()
166                 except SkipTest as e:
167                     self._addSkip(result, str(e))
168                 except KeyboardInterrupt:
169                     raise
170                 except:
171                     result.addError(self, self._exc_info())
172                     ok = False
173
174                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
175                     fn(*args, **kwargs)
176                 if ok: result.addSuccess(self)
177             finally:
178                 result.stopTest(self)
179
180
181 class LdbTestCase(TestCase):
182     """Trivial test case for running tests against a LDB."""
183
184     def setUp(self):
185         super(LdbTestCase, self).setUp()
186         self.filename = os.tempnam()
187         self.ldb = samba.Ldb(self.filename)
188
189     def set_modules(self, modules=[]):
190         """Change the modules for this Ldb."""
191         m = ldb.Message()
192         m.dn = ldb.Dn(self.ldb, "@MODULES")
193         m["@LIST"] = ",".join(modules)
194         self.ldb.add(m)
195         self.ldb = samba.Ldb(self.filename)
196
197
198 class TestCaseInTempDir(TestCase):
199
200     def setUp(self):
201         super(TestCaseInTempDir, self).setUp()
202         self.tempdir = tempfile.mkdtemp()
203         self.addCleanup(self._remove_tempdir)
204
205     def _remove_tempdir(self):
206         self.assertEquals([], os.listdir(self.tempdir))
207         os.rmdir(self.tempdir)
208         self.tempdir = None
209
210
211 def env_loadparm():
212     lp = param.LoadParm()
213     try:
214         lp.load(os.environ["SMB_CONF_PATH"])
215     except KeyError:
216         raise KeyError("SMB_CONF_PATH not set")
217     return lp
218
219
220 def env_get_var_value(var_name, allow_missing=False):
221     """Returns value for variable in os.environ
222
223     Function throws AssertionError if variable is defined.
224     Unit-test based python tests require certain input params
225     to be set in environment, otherwise they can't be run
226     """
227     if allow_missing:
228         if var_name not in os.environ.keys():
229             return None
230     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
231     return os.environ[var_name]
232
233
234 cmdline_credentials = None
235
236 class RpcInterfaceTestCase(TestCase):
237     """DCE/RPC Test case."""
238
239
240 class ValidNetbiosNameTests(TestCase):
241
242     def test_valid(self):
243         self.assertTrue(samba.valid_netbios_name("FOO"))
244
245     def test_too_long(self):
246         self.assertFalse(samba.valid_netbios_name("FOO"*10))
247
248     def test_invalid_characters(self):
249         self.assertFalse(samba.valid_netbios_name("*BLA"))
250
251
252 class BlackboxProcessError(Exception):
253     """This is raised when check_output() process returns a non-zero exit status
254
255     Exception instance should contain the exact exit code (S.returncode),
256     command line (S.cmd), process output (S.stdout) and process error stream
257     (S.stderr)
258     """
259
260     def __init__(self, returncode, cmd, stdout, stderr):
261         self.returncode = returncode
262         self.cmd = cmd
263         self.stdout = stdout
264         self.stderr = stderr
265
266     def __str__(self):
267         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
268                                                                              self.stdout, self.stderr)
269
270 class BlackboxTestCase(TestCaseInTempDir):
271     """Base test case for blackbox tests."""
272
273     def _make_cmdline(self, line):
274         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
275         parts = line.split(" ")
276         if os.path.exists(os.path.join(bindir, parts[0])):
277             parts[0] = os.path.join(bindir, parts[0])
278         line = " ".join(parts)
279         return line
280
281     def check_run(self, line):
282         line = self._make_cmdline(line)
283         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
284         retcode = p.wait()
285         if retcode:
286             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
287
288     def check_output(self, line):
289         line = self._make_cmdline(line)
290         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
291         retcode = p.wait()
292         if retcode:
293             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
294         return p.stdout.read()
295
296
297 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
298                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
299     """Create SamDB instance and connects to samdb_url database.
300
301     :param samdb_url: Url for database to connect to.
302     :param lp: Optional loadparm object
303     :param session_info: Optional session information
304     :param credentials: Optional credentials, defaults to anonymous.
305     :param flags: Optional LDB flags
306     :param ldap_only: If set, only remote LDAP connection will be created.
307     :param global_schema: Whether to use global schema.
308
309     Added value for tests is that we have a shorthand function
310     to make proper URL for ldb.connect() while using default
311     parameters for connection based on test environment
312     """
313     if not "://" in samdb_url:
314         if not ldap_only and os.path.isfile(samdb_url):
315             samdb_url = "tdb://%s" % samdb_url
316         else:
317             samdb_url = "ldap://%s" % samdb_url
318     # use 'paged_search' module when connecting remotely
319     if samdb_url.startswith("ldap://"):
320         ldb_options = ["modules:paged_searches"]
321     elif ldap_only:
322         raise AssertionError("Trying to connect to %s while remote "
323                              "connection is required" % samdb_url)
324
325     # set defaults for test environment
326     if lp is None:
327         lp = env_loadparm()
328     if session_info is None:
329         session_info = samba.auth.system_session(lp)
330     if credentials is None:
331         credentials = cmdline_credentials
332
333     return SamDB(url=samdb_url,
334                  lp=lp,
335                  session_info=session_info,
336                  credentials=credentials,
337                  flags=flags,
338                  options=ldb_options,
339                  global_schema=global_schema)
340
341
342 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
343                      flags=0, ldb_options=None, ldap_only=False):
344     """Connects to samdb_url database
345
346     :param samdb_url: Url for database to connect to.
347     :param lp: Optional loadparm object
348     :param session_info: Optional session information
349     :param credentials: Optional credentials, defaults to anonymous.
350     :param flags: Optional LDB flags
351     :param ldap_only: If set, only remote LDAP connection will be created.
352     :return: (sam_db_connection, rootDse_record) tuple
353     """
354     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
355                            flags, ldb_options, ldap_only)
356     # fetch RootDse
357     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
358                         attrs=["*"])
359     return (sam_db, res[0])
360
361
362 def connect_samdb_env(env_url, env_username, env_password, lp=None):
363     """Connect to SamDB by getting URL and Credentials from environment
364
365     :param env_url: Environment variable name to get lsb url from
366     :param env_username: Username environment variable
367     :param env_password: Password environment variable
368     :return: sam_db_connection
369     """
370     samdb_url = env_get_var_value(env_url)
371     creds = credentials.Credentials()
372     if lp is None:
373         # guess Credentials parameters here. Otherwise workstation
374         # and domain fields are NULL and gencache code segfalts
375         lp = param.LoadParm()
376         creds.guess(lp)
377     creds.set_username(env_get_var_value(env_username))
378     creds.set_password(env_get_var_value(env_password))
379     return connect_samdb(samdb_url, credentials=creds, lp=lp)
380
381
382 def delete_force(samdb, dn):
383     try:
384         samdb.delete(dn)
385     except ldb.LdbError as error:
386         (num, errstr) = error.args
387         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr