2ddfd9d2273f398c72d9e5ef24f32e40d7ef61f3
[samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18
19 """Samba Python tests."""
20
21 import os
22 import ldb
23 import samba
24 from samba import param
25 from samba import credentials
26 from samba.credentials import Credentials
27 from samba import gensec
28 import socket
29 import struct
30 import subprocess
31 import sys
32 import tempfile
33 import unittest
34 import samba.auth
35 import samba.dcerpc.base
36 from samba.compat import PY3
37 if not PY3:
38     # Py2 only
39     from samba.samdb import SamDB
40     import samba.ndr
41     import samba.dcerpc.dcerpc
42     import samba.dcerpc.epmapper
43
44 try:
45     from unittest import SkipTest
46 except ImportError:
47     class SkipTest(Exception):
48         """Test skipped."""
49
50 HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
51
52 class TestCase(unittest.TestCase):
53     """A Samba test case."""
54
55     def setUp(self):
56         super(TestCase, self).setUp()
57         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
58         if test_debug_level is not None:
59             test_debug_level = int(test_debug_level)
60             self._old_debug_level = samba.get_debug_level()
61             samba.set_debug_level(test_debug_level)
62             self.addCleanup(samba.set_debug_level, test_debug_level)
63
64     def get_loadparm(self):
65         return env_loadparm()
66
67     def get_credentials(self):
68         return cmdline_credentials
69
70     def get_creds_ccache_name(self):
71         creds = self.get_credentials()
72         ccache = creds.get_named_ccache(self.get_loadparm())
73         ccache_name = ccache.get_name()
74
75         return ccache_name
76
77     def hexdump(self, src):
78         N = 0
79         result = ''
80         while src:
81             ll = src[:8]
82             lr = src[8:16]
83             src = src[16:]
84             hl = ' '.join(["%02X" % ord(x) for x in ll])
85             hr = ' '.join(["%02X" % ord(x) for x in lr])
86             ll = ll.translate(HEXDUMP_FILTER)
87             lr = lr.translate(HEXDUMP_FILTER)
88             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
89             N += 16
90         return result
91
92     def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
93
94         if template is None:
95             assert template is not None
96
97         if username is not None:
98             assert userpass is not None
99
100         if username is None:
101             assert userpass is None
102
103             username = template.get_username()
104             userpass = template.get_password()
105
106         if kerberos_state is None:
107             kerberos_state = template.get_kerberos_state()
108
109         # get a copy of the global creds or a the passed in creds
110         c = Credentials()
111         c.set_username(username)
112         c.set_password(userpass)
113         c.set_domain(template.get_domain())
114         c.set_realm(template.get_realm())
115         c.set_workstation(template.get_workstation())
116         c.set_gensec_features(c.get_gensec_features()
117                               | gensec.FEATURE_SEAL)
118         c.set_kerberos_state(kerberos_state)
119         return c
120
121
122
123     # These functions didn't exist before Python2.7:
124     if sys.version_info < (2, 7):
125         import warnings
126
127         def skipTest(self, reason):
128             raise SkipTest(reason)
129
130         def assertIn(self, member, container, msg=None):
131             self.assertTrue(member in container, msg)
132
133         def assertIs(self, a, b, msg=None):
134             self.assertTrue(a is b, msg)
135
136         def assertIsNot(self, a, b, msg=None):
137             self.assertTrue(a is not b, msg)
138
139         def assertIsNotNone(self, a, msg=None):
140             self.assertTrue(a is not None)
141
142         def assertIsInstance(self, a, b, msg=None):
143             self.assertTrue(isinstance(a, b), msg)
144
145         def assertIsNone(self, a, msg=None):
146             self.assertTrue(a is None, msg)
147
148         def assertGreater(self, a, b, msg=None):
149             self.assertTrue(a > b, msg)
150
151         def assertGreaterEqual(self, a, b, msg=None):
152             self.assertTrue(a >= b, msg)
153
154         def assertLess(self, a, b, msg=None):
155             self.assertTrue(a < b, msg)
156
157         def assertLessEqual(self, a, b, msg=None):
158             self.assertTrue(a <= b, msg)
159
160         def addCleanup(self, fn, *args, **kwargs):
161             self._cleanups = getattr(self, "_cleanups", []) + [
162                 (fn, args, kwargs)]
163
164         def _addSkip(self, result, reason):
165             addSkip = getattr(result, 'addSkip', None)
166             if addSkip is not None:
167                 addSkip(self, reason)
168             else:
169                 warnings.warn("TestResult has no addSkip method, skips not reported",
170                               RuntimeWarning, 2)
171                 result.addSuccess(self)
172
173         def run(self, result=None):
174             if result is None: result = self.defaultTestResult()
175             result.startTest(self)
176             testMethod = getattr(self, self._testMethodName)
177             try:
178                 try:
179                     self.setUp()
180                 except SkipTest as e:
181                     self._addSkip(result, str(e))
182                     return
183                 except KeyboardInterrupt:
184                     raise
185                 except:
186                     result.addError(self, self._exc_info())
187                     return
188
189                 ok = False
190                 try:
191                     testMethod()
192                     ok = True
193                 except SkipTest as e:
194                     self._addSkip(result, str(e))
195                     return
196                 except self.failureException:
197                     result.addFailure(self, self._exc_info())
198                 except KeyboardInterrupt:
199                     raise
200                 except:
201                     result.addError(self, self._exc_info())
202
203                 try:
204                     self.tearDown()
205                 except SkipTest as e:
206                     self._addSkip(result, str(e))
207                 except KeyboardInterrupt:
208                     raise
209                 except:
210                     result.addError(self, self._exc_info())
211                     ok = False
212
213                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
214                     fn(*args, **kwargs)
215                 if ok: result.addSuccess(self)
216             finally:
217                 result.stopTest(self)
218
219
220 class LdbTestCase(TestCase):
221     """Trivial test case for running tests against a LDB."""
222
223     def setUp(self):
224         super(LdbTestCase, self).setUp()
225         self.filename = os.tempnam()
226         self.ldb = samba.Ldb(self.filename)
227
228     def set_modules(self, modules=[]):
229         """Change the modules for this Ldb."""
230         m = ldb.Message()
231         m.dn = ldb.Dn(self.ldb, "@MODULES")
232         m["@LIST"] = ",".join(modules)
233         self.ldb.add(m)
234         self.ldb = samba.Ldb(self.filename)
235
236
237 class TestCaseInTempDir(TestCase):
238
239     def setUp(self):
240         super(TestCaseInTempDir, self).setUp()
241         self.tempdir = tempfile.mkdtemp()
242         self.addCleanup(self._remove_tempdir)
243
244     def _remove_tempdir(self):
245         self.assertEquals([], os.listdir(self.tempdir))
246         os.rmdir(self.tempdir)
247         self.tempdir = None
248
249
250 def env_loadparm():
251     lp = param.LoadParm()
252     try:
253         lp.load(os.environ["SMB_CONF_PATH"])
254     except KeyError:
255         raise KeyError("SMB_CONF_PATH not set")
256     return lp
257
258
259 def env_get_var_value(var_name, allow_missing=False):
260     """Returns value for variable in os.environ
261
262     Function throws AssertionError if variable is defined.
263     Unit-test based python tests require certain input params
264     to be set in environment, otherwise they can't be run
265     """
266     if allow_missing:
267         if var_name not in os.environ.keys():
268             return None
269     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
270     return os.environ[var_name]
271
272
273 cmdline_credentials = None
274
275 class RpcInterfaceTestCase(TestCase):
276     """DCE/RPC Test case."""
277
278
279 class ValidNetbiosNameTests(TestCase):
280
281     def test_valid(self):
282         self.assertTrue(samba.valid_netbios_name("FOO"))
283
284     def test_too_long(self):
285         self.assertFalse(samba.valid_netbios_name("FOO"*10))
286
287     def test_invalid_characters(self):
288         self.assertFalse(samba.valid_netbios_name("*BLA"))
289
290
291 class BlackboxProcessError(Exception):
292     """This is raised when check_output() process returns a non-zero exit status
293
294     Exception instance should contain the exact exit code (S.returncode),
295     command line (S.cmd), process output (S.stdout) and process error stream
296     (S.stderr)
297     """
298
299     def __init__(self, returncode, cmd, stdout, stderr):
300         self.returncode = returncode
301         self.cmd = cmd
302         self.stdout = stdout
303         self.stderr = stderr
304
305     def __str__(self):
306         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
307                                                                              self.stdout, self.stderr)
308
309 class BlackboxTestCase(TestCaseInTempDir):
310     """Base test case for blackbox tests."""
311
312     def _make_cmdline(self, line):
313         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
314         parts = line.split(" ")
315         if os.path.exists(os.path.join(bindir, parts[0])):
316             parts[0] = os.path.join(bindir, parts[0])
317         line = " ".join(parts)
318         return line
319
320     def check_run(self, line):
321         line = self._make_cmdline(line)
322         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
323         retcode = p.wait()
324         if retcode:
325             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
326
327     def check_output(self, line):
328         line = self._make_cmdline(line)
329         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
330         retcode = p.wait()
331         if retcode:
332             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
333         return p.stdout.read()
334
335
336 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
337                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
338     """Create SamDB instance and connects to samdb_url database.
339
340     :param samdb_url: Url for database to connect to.
341     :param lp: Optional loadparm object
342     :param session_info: Optional session information
343     :param credentials: Optional credentials, defaults to anonymous.
344     :param flags: Optional LDB flags
345     :param ldap_only: If set, only remote LDAP connection will be created.
346     :param global_schema: Whether to use global schema.
347
348     Added value for tests is that we have a shorthand function
349     to make proper URL for ldb.connect() while using default
350     parameters for connection based on test environment
351     """
352     if not "://" in samdb_url:
353         if not ldap_only and os.path.isfile(samdb_url):
354             samdb_url = "tdb://%s" % samdb_url
355         else:
356             samdb_url = "ldap://%s" % samdb_url
357     # use 'paged_search' module when connecting remotely
358     if samdb_url.startswith("ldap://"):
359         ldb_options = ["modules:paged_searches"]
360     elif ldap_only:
361         raise AssertionError("Trying to connect to %s while remote "
362                              "connection is required" % samdb_url)
363
364     # set defaults for test environment
365     if lp is None:
366         lp = env_loadparm()
367     if session_info is None:
368         session_info = samba.auth.system_session(lp)
369     if credentials is None:
370         credentials = cmdline_credentials
371
372     return SamDB(url=samdb_url,
373                  lp=lp,
374                  session_info=session_info,
375                  credentials=credentials,
376                  flags=flags,
377                  options=ldb_options,
378                  global_schema=global_schema)
379
380
381 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
382                      flags=0, ldb_options=None, ldap_only=False):
383     """Connects to samdb_url database
384
385     :param samdb_url: Url for database to connect to.
386     :param lp: Optional loadparm object
387     :param session_info: Optional session information
388     :param credentials: Optional credentials, defaults to anonymous.
389     :param flags: Optional LDB flags
390     :param ldap_only: If set, only remote LDAP connection will be created.
391     :return: (sam_db_connection, rootDse_record) tuple
392     """
393     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
394                            flags, ldb_options, ldap_only)
395     # fetch RootDse
396     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
397                         attrs=["*"])
398     return (sam_db, res[0])
399
400
401 def connect_samdb_env(env_url, env_username, env_password, lp=None):
402     """Connect to SamDB by getting URL and Credentials from environment
403
404     :param env_url: Environment variable name to get lsb url from
405     :param env_username: Username environment variable
406     :param env_password: Password environment variable
407     :return: sam_db_connection
408     """
409     samdb_url = env_get_var_value(env_url)
410     creds = credentials.Credentials()
411     if lp is None:
412         # guess Credentials parameters here. Otherwise workstation
413         # and domain fields are NULL and gencache code segfalts
414         lp = param.LoadParm()
415         creds.guess(lp)
416     creds.set_username(env_get_var_value(env_username))
417     creds.set_password(env_get_var_value(env_password))
418     return connect_samdb(samdb_url, credentials=creds, lp=lp)
419
420
421 def delete_force(samdb, dn, **kwargs):
422     try:
423         samdb.delete(dn, **kwargs)
424     except ldb.LdbError as error:
425         (num, errstr) = error.args
426         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr