Handle skips when running on python2.6.
[samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 """Samba Python tests."""
19
20 import os
21 import ldb
22 import samba
23 import samba.auth
24 from samba import param
25 from samba.samdb import SamDB
26 from samba import credentials
27 import subprocess
28 import sys
29 import tempfile
30 import unittest
31
32 try:
33     from unittest import SkipTest
34 except ImportError:
35     class SkipTest(Exception):
36         """Test skipped."""
37
38
39 class TestCase(unittest.TestCase):
40     """A Samba test case."""
41
42     def setUp(self):
43         super(TestCase, self).setUp()
44         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
45         if test_debug_level is not None:
46             test_debug_level = int(test_debug_level)
47             self._old_debug_level = samba.get_debug_level()
48             samba.set_debug_level(test_debug_level)
49             self.addCleanup(samba.set_debug_level, test_debug_level)
50
51     def get_loadparm(self):
52         return env_loadparm()
53
54     def get_credentials(self):
55         return cmdline_credentials
56
57     # These functions didn't exist before Python2.7:
58     if sys.version_info < (2, 7):
59         import warnings
60
61         def skipTest(self, reason):
62             raise SkipTest(reason)
63
64         def assertIs(self, a, b):
65             self.assertTrue(a is b)
66
67         def assertIsNot(self, a, b):
68             self.assertTrue(a is not b)
69
70         def assertIsInstance(self, a, b):
71             self.assertTrue(isinstance(a, b))
72
73         def addCleanup(self, fn, *args, **kwargs):
74             self._cleanups = getattr(self, "_cleanups", []) + [
75                 (fn, args, kwargs)]
76
77         def _addSkip(self, result, reason):
78             addSkip = getattr(result, 'addSkip', None)
79             if addSkip is not None:
80                 addSkip(self, reason)
81             else:
82                 warnings.warn("TestResult has no addSkip method, skips not reported",
83                               RuntimeWarning, 2)
84                 result.addSuccess(self)
85
86         def run(self, result=None):
87             if result is None: result = self.defaultTestResult()
88             result.startTest(self)
89             testMethod = getattr(self, self._testMethodName)
90             try:
91                 try:
92                     self.setUp()
93                 except SkipTest, e:
94                     self._addSkip(result, str(e))
95                     return
96                 except KeyboardInterrupt:
97                     raise
98                 except:
99                     result.addError(self, self._exc_info())
100                     return
101
102                 ok = False
103                 try:
104                     testMethod()
105                     ok = True
106                 except SkipTest, e:
107                     self._addSkip(result, str(e))
108                     return
109                 except self.failureException:
110                     result.addFailure(self, self._exc_info())
111                 except KeyboardInterrupt:
112                     raise
113                 except:
114                     result.addError(self, self._exc_info())
115
116                 try:
117                     self.tearDown()
118                 except SkipTest, e:
119                     self._addSkip(result, str(e))
120                 except KeyboardInterrupt:
121                     raise
122                 except:
123                     result.addError(self, self._exc_info())
124                     ok = False
125
126                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
127                     fn(*args, **kwargs)
128                 if ok: result.addSuccess(self)
129             finally:
130                 result.stopTest(self)
131
132
133 class LdbTestCase(TestCase):
134     """Trivial test case for running tests against a LDB."""
135
136     def setUp(self):
137         super(LdbTestCase, self).setUp()
138         self.filename = os.tempnam()
139         self.ldb = samba.Ldb(self.filename)
140
141     def set_modules(self, modules=[]):
142         """Change the modules for this Ldb."""
143         m = ldb.Message()
144         m.dn = ldb.Dn(self.ldb, "@MODULES")
145         m["@LIST"] = ",".join(modules)
146         self.ldb.add(m)
147         self.ldb = samba.Ldb(self.filename)
148
149
150 class TestCaseInTempDir(TestCase):
151
152     def setUp(self):
153         super(TestCaseInTempDir, self).setUp()
154         self.tempdir = tempfile.mkdtemp()
155         self.addCleanup(self._remove_tempdir)
156
157     def _remove_tempdir(self):
158         self.assertEquals([], os.listdir(self.tempdir))
159         os.rmdir(self.tempdir)
160         self.tempdir = None
161
162
163 def env_loadparm():
164     lp = param.LoadParm()
165     try:
166         lp.load(os.environ["SMB_CONF_PATH"])
167     except KeyError:
168         raise KeyError("SMB_CONF_PATH not set")
169     return lp
170
171
172 def env_get_var_value(var_name):
173     """Returns value for variable in os.environ
174
175     Function throws AssertionError if variable is defined.
176     Unit-test based python tests require certain input params
177     to be set in environment, otherwise they can't be run
178     """
179     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
180     return os.environ[var_name]
181
182
183 cmdline_credentials = None
184
185 class RpcInterfaceTestCase(TestCase):
186     """DCE/RPC Test case."""
187
188
189 class ValidNetbiosNameTests(TestCase):
190
191     def test_valid(self):
192         self.assertTrue(samba.valid_netbios_name("FOO"))
193
194     def test_too_long(self):
195         self.assertFalse(samba.valid_netbios_name("FOO"*10))
196
197     def test_invalid_characters(self):
198         self.assertFalse(samba.valid_netbios_name("*BLA"))
199
200
201 class BlackboxProcessError(Exception):
202     """This is raised when check_output() process returns a non-zero exit status
203
204     Exception instance should contain the exact exit code (S.returncode),
205     command line (S.cmd), process output (S.stdout) and process error stream
206     (S.stderr)
207     """
208
209     def __init__(self, returncode, cmd, stdout, stderr):
210         self.returncode = returncode
211         self.cmd = cmd
212         self.stdout = stdout
213         self.stderr = stderr
214
215     def __str__(self):
216         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
217                                                                              self.stdout, self.stderr)
218
219 class BlackboxTestCase(TestCase):
220     """Base test case for blackbox tests."""
221
222     def _make_cmdline(self, line):
223         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
224         parts = line.split(" ")
225         if os.path.exists(os.path.join(bindir, parts[0])):
226             parts[0] = os.path.join(bindir, parts[0])
227         line = " ".join(parts)
228         return line
229
230     def check_run(self, line):
231         line = self._make_cmdline(line)
232         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
233         retcode = p.wait()
234         if retcode:
235             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
236
237     def check_output(self, line):
238         line = self._make_cmdline(line)
239         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
240         retcode = p.wait()
241         if retcode:
242             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
243         return p.stdout.read()
244
245
246 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
247                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
248     """Create SamDB instance and connects to samdb_url database.
249
250     :param samdb_url: Url for database to connect to.
251     :param lp: Optional loadparm object
252     :param session_info: Optional session information
253     :param credentials: Optional credentials, defaults to anonymous.
254     :param flags: Optional LDB flags
255     :param ldap_only: If set, only remote LDAP connection will be created.
256     :param global_schema: Whether to use global schema.
257
258     Added value for tests is that we have a shorthand function
259     to make proper URL for ldb.connect() while using default
260     parameters for connection based on test environment
261     """
262     if not "://" in samdb_url:
263         if not ldap_only and os.path.isfile(samdb_url):
264             samdb_url = "tdb://%s" % samdb_url
265         else:
266             samdb_url = "ldap://%s" % samdb_url
267     # use 'paged_search' module when connecting remotely
268     if samdb_url.startswith("ldap://"):
269         ldb_options = ["modules:paged_searches"]
270     elif ldap_only:
271         raise AssertionError("Trying to connect to %s while remote "
272                              "connection is required" % samdb_url)
273
274     # set defaults for test environment
275     if lp is None:
276         lp = env_loadparm()
277     if session_info is None:
278         session_info = samba.auth.system_session(lp)
279     if credentials is None:
280         credentials = cmdline_credentials
281
282     return SamDB(url=samdb_url,
283                  lp=lp,
284                  session_info=session_info,
285                  credentials=credentials,
286                  flags=flags,
287                  options=ldb_options,
288                  global_schema=global_schema)
289
290
291 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
292                      flags=0, ldb_options=None, ldap_only=False):
293     """Connects to samdb_url database
294
295     :param samdb_url: Url for database to connect to.
296     :param lp: Optional loadparm object
297     :param session_info: Optional session information
298     :param credentials: Optional credentials, defaults to anonymous.
299     :param flags: Optional LDB flags
300     :param ldap_only: If set, only remote LDAP connection will be created.
301     :return: (sam_db_connection, rootDse_record) tuple
302     """
303     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
304                            flags, ldb_options, ldap_only)
305     # fetch RootDse
306     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
307                         attrs=["*"])
308     return (sam_db, res[0])
309
310
311 def connect_samdb_env(env_url, env_username, env_password, lp=None):
312     """Connect to SamDB by getting URL and Credentials from environment
313
314     :param env_url: Environment variable name to get lsb url from
315     :param env_username: Username environment variable
316     :param env_password: Password environment variable
317     :return: sam_db_connection
318     """
319     samdb_url = env_get_var_value(env_url)
320     creds = credentials.Credentials()
321     if lp is None:
322         # guess Credentials parameters here. Otherwise workstation
323         # and domain fields are NULL and gencache code segfalts
324         lp = param.LoadParm()
325         creds.guess(lp)
326     creds.set_username(env_get_var_value(env_username))
327     creds.set_password(env_get_var_value(env_password))
328     return connect_samdb(samdb_url, credentials=creds, lp=lp)
329
330
331 def delete_force(samdb, dn):
332     try:
333         samdb.delete(dn)
334     except ldb.LdbError, (num, errstr):
335         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr