99cea490a1c2c6824d984ad40c1fc13713239402
[samba.git] / python / samba / tests / __init__.py
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 #
17
18 """Samba Python tests."""
19
20 import os
21 import ldb
22 import samba
23 import samba.auth
24 from samba import param
25 from samba.samdb import SamDB
26 from samba import credentials
27 import subprocess
28 import tempfile
29 import unittest
30
31 try:
32     from unittest import SkipTest
33 except ImportError:
34     class SkipTest(Exception):
35         """Test skipped."""
36
37
38 class TestCase(unittest.TestCase):
39     """A Samba test case."""
40
41     def setUp(self):
42         super(TestCase, self).setUp()
43         test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
44         if test_debug_level is not None:
45             test_debug_level = int(test_debug_level)
46             self._old_debug_level = samba.get_debug_level()
47             samba.set_debug_level(test_debug_level)
48             self.addCleanup(samba.set_debug_level, test_debug_level)
49
50     def get_loadparm(self):
51         return env_loadparm()
52
53     def get_credentials(self):
54         return cmdline_credentials
55
56     # These functions didn't exist before Python2.7:
57     if not getattr(unittest.TestCase, "skipTest", None):
58         def skipTest(self, reason):
59             raise SkipTest(reason)
60
61     if not getattr(unittest.TestCase, "assertIs", None):
62         def assertIs(self, a, b):
63             self.assertTrue(a is b)
64
65     if not getattr(unittest.TestCase, "assertIsNot", None):
66         def assertIsNot(self, a, b):
67             self.assertTrue(a is not b)
68
69     if not getattr(unittest.TestCase, "assertIsInstance", None):
70         def assertIsInstance(self, a, b):
71             self.assertTrue(isinstance(a, b))
72
73     if not getattr(unittest.TestCase, "addCleanup", None):
74         def addCleanup(self, fn, *args, **kwargs):
75             self._cleanups = getattr(self, "_cleanups", []) + [
76                 (fn, args, kwargs)]
77
78         def run(self, result=None):
79             ret = super(TestCase, self).run(result=result)
80             for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
81                 fn(*args, **kwargs)
82             return ret
83
84
85 class LdbTestCase(TestCase):
86     """Trivial test case for running tests against a LDB."""
87
88     def setUp(self):
89         super(LdbTestCase, self).setUp()
90         self.filename = os.tempnam()
91         self.ldb = samba.Ldb(self.filename)
92
93     def set_modules(self, modules=[]):
94         """Change the modules for this Ldb."""
95         m = ldb.Message()
96         m.dn = ldb.Dn(self.ldb, "@MODULES")
97         m["@LIST"] = ",".join(modules)
98         self.ldb.add(m)
99         self.ldb = samba.Ldb(self.filename)
100
101
102 class TestCaseInTempDir(TestCase):
103
104     def setUp(self):
105         super(TestCaseInTempDir, self).setUp()
106         self.tempdir = tempfile.mkdtemp()
107         self.addCleanup(self._remove_tempdir)
108
109     def _remove_tempdir(self):
110         self.assertEquals([], os.listdir(self.tempdir))
111         os.rmdir(self.tempdir)
112         self.tempdir = None
113
114
115 def env_loadparm():
116     lp = param.LoadParm()
117     try:
118         lp.load(os.environ["SMB_CONF_PATH"])
119     except KeyError:
120         raise KeyError("SMB_CONF_PATH not set")
121     return lp
122
123
124 def env_get_var_value(var_name):
125     """Returns value for variable in os.environ
126
127     Function throws AssertionError if variable is defined.
128     Unit-test based python tests require certain input params
129     to be set in environment, otherwise they can't be run
130     """
131     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
132     return os.environ[var_name]
133
134
135 cmdline_credentials = None
136
137 class RpcInterfaceTestCase(TestCase):
138     """DCE/RPC Test case."""
139
140
141 class ValidNetbiosNameTests(TestCase):
142
143     def test_valid(self):
144         self.assertTrue(samba.valid_netbios_name("FOO"))
145
146     def test_too_long(self):
147         self.assertFalse(samba.valid_netbios_name("FOO"*10))
148
149     def test_invalid_characters(self):
150         self.assertFalse(samba.valid_netbios_name("*BLA"))
151
152
153 class BlackboxProcessError(Exception):
154     """This is raised when check_output() process returns a non-zero exit status
155
156     Exception instance should contain the exact exit code (S.returncode),
157     command line (S.cmd), process output (S.stdout) and process error stream
158     (S.stderr)
159     """
160
161     def __init__(self, returncode, cmd, stdout, stderr):
162         self.returncode = returncode
163         self.cmd = cmd
164         self.stdout = stdout
165         self.stderr = stderr
166
167     def __str__(self):
168         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
169                                                                              self.stdout, self.stderr)
170
171 class BlackboxTestCase(TestCase):
172     """Base test case for blackbox tests."""
173
174     def _make_cmdline(self, line):
175         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
176         parts = line.split(" ")
177         if os.path.exists(os.path.join(bindir, parts[0])):
178             parts[0] = os.path.join(bindir, parts[0])
179         line = " ".join(parts)
180         return line
181
182     def check_run(self, line):
183         line = self._make_cmdline(line)
184         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
185         retcode = p.wait()
186         if retcode:
187             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
188
189     def check_output(self, line):
190         line = self._make_cmdline(line)
191         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
192         retcode = p.wait()
193         if retcode:
194             raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
195         return p.stdout.read()
196
197
198 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
199                   flags=0, ldb_options=None, ldap_only=False, global_schema=True):
200     """Create SamDB instance and connects to samdb_url database.
201
202     :param samdb_url: Url for database to connect to.
203     :param lp: Optional loadparm object
204     :param session_info: Optional session information
205     :param credentials: Optional credentials, defaults to anonymous.
206     :param flags: Optional LDB flags
207     :param ldap_only: If set, only remote LDAP connection will be created.
208     :param global_schema: Whether to use global schema.
209
210     Added value for tests is that we have a shorthand function
211     to make proper URL for ldb.connect() while using default
212     parameters for connection based on test environment
213     """
214     if not "://" in samdb_url:
215         if not ldap_only and os.path.isfile(samdb_url):
216             samdb_url = "tdb://%s" % samdb_url
217         else:
218             samdb_url = "ldap://%s" % samdb_url
219     # use 'paged_search' module when connecting remotely
220     if samdb_url.startswith("ldap://"):
221         ldb_options = ["modules:paged_searches"]
222     elif ldap_only:
223         raise AssertionError("Trying to connect to %s while remote "
224                              "connection is required" % samdb_url)
225
226     # set defaults for test environment
227     if lp is None:
228         lp = env_loadparm()
229     if session_info is None:
230         session_info = samba.auth.system_session(lp)
231     if credentials is None:
232         credentials = cmdline_credentials
233
234     return SamDB(url=samdb_url,
235                  lp=lp,
236                  session_info=session_info,
237                  credentials=credentials,
238                  flags=flags,
239                  options=ldb_options,
240                  global_schema=global_schema)
241
242
243 def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
244                      flags=0, ldb_options=None, ldap_only=False):
245     """Connects to samdb_url database
246
247     :param samdb_url: Url for database to connect to.
248     :param lp: Optional loadparm object
249     :param session_info: Optional session information
250     :param credentials: Optional credentials, defaults to anonymous.
251     :param flags: Optional LDB flags
252     :param ldap_only: If set, only remote LDAP connection will be created.
253     :return: (sam_db_connection, rootDse_record) tuple
254     """
255     sam_db = connect_samdb(samdb_url, lp, session_info, credentials,
256                            flags, ldb_options, ldap_only)
257     # fetch RootDse
258     res = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
259                         attrs=["*"])
260     return (sam_db, res[0])
261
262
263 def connect_samdb_env(env_url, env_username, env_password, lp=None):
264     """Connect to SamDB by getting URL and Credentials from environment
265
266     :param env_url: Environment variable name to get lsb url from
267     :param env_username: Username environment variable
268     :param env_password: Password environment variable
269     :return: sam_db_connection
270     """
271     samdb_url = env_get_var_value(env_url)
272     creds = credentials.Credentials()
273     if lp is None:
274         # guess Credentials parameters here. Otherwise workstation
275         # and domain fields are NULL and gencache code segfalts
276         lp = param.LoadParm()
277         creds.guess(lp)
278     creds.set_username(env_get_var_value(env_username))
279     creds.set_password(env_get_var_value(env_password))
280     return connect_samdb(samdb_url, credentials=creds, lp=lp)
281
282
283 def delete_force(samdb, dn):
284     try:
285         samdb.delete(dn)
286     except ldb.LdbError, (num, errstr):
287         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr