blackbox tests: method to check specific exit codes
[samba.git] / python / samba / tests / __init__.py
index 8e662ed156425af69048d853776969e0222a8465..d012113cda682e73f17cb74f29edd6f86a0743c7 100644 (file)
@@ -1,5 +1,6 @@
 # Unix SMB/CIFS implementation.
 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
+# Copyright (C) Stefan Metzmacher 2014,2015
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 import os
 import ldb
 import samba
-import samba.auth
 from samba import param
-from samba.samdb import SamDB
+from samba import credentials
+from samba.credentials import Credentials
+from samba import gensec
+import socket
+import struct
 import subprocess
+import sys
 import tempfile
-
-samba.ensure_external_module("mimeparse", "mimeparse")
-samba.ensure_external_module("extras", "extras")
-samba.ensure_external_module("testtools", "testtools")
-
-# Other modules import these two classes from here, for convenience:
-from testtools.testcase import (
-    TestCase as TesttoolsTestCase,
-    TestSkipped,
-    )
-
-
-class TestCase(TesttoolsTestCase):
+import unittest
+import samba.auth
+import samba.dcerpc.base
+from samba.compat import PY3
+if not PY3:
+    # Py2 only
+    from samba.samdb import SamDB
+    import samba.ndr
+    import samba.dcerpc.dcerpc
+    import samba.dcerpc.epmapper
+
+try:
+    from unittest import SkipTest
+except ImportError:
+    class SkipTest(Exception):
+        """Test skipped."""
+
+HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
+
+class TestCase(unittest.TestCase):
     """A Samba test case."""
 
     def setUp(self):
@@ -55,8 +67,157 @@ class TestCase(TesttoolsTestCase):
     def get_credentials(self):
         return cmdline_credentials
 
+    def get_creds_ccache_name(self):
+        creds = self.get_credentials()
+        ccache = creds.get_named_ccache(self.get_loadparm())
+        ccache_name = ccache.get_name()
 
-class LdbTestCase(TesttoolsTestCase):
+        return ccache_name
+
+    def hexdump(self, src):
+        N = 0
+        result = ''
+        while src:
+            ll = src[:8]
+            lr = src[8:16]
+            src = src[16:]
+            hl = ' '.join(["%02X" % ord(x) for x in ll])
+            hr = ' '.join(["%02X" % ord(x) for x in lr])
+            ll = ll.translate(HEXDUMP_FILTER)
+            lr = lr.translate(HEXDUMP_FILTER)
+            result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
+            N += 16
+        return result
+
+    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
+
+        if template is None:
+            assert template is not None
+
+        if username is not None:
+            assert userpass is not None
+
+        if username is None:
+            assert userpass is None
+
+            username = template.get_username()
+            userpass = template.get_password()
+
+        if kerberos_state is None:
+            kerberos_state = template.get_kerberos_state()
+
+        # get a copy of the global creds or a the passed in creds
+        c = Credentials()
+        c.set_username(username)
+        c.set_password(userpass)
+        c.set_domain(template.get_domain())
+        c.set_realm(template.get_realm())
+        c.set_workstation(template.get_workstation())
+        c.set_gensec_features(c.get_gensec_features()
+                              | gensec.FEATURE_SEAL)
+        c.set_kerberos_state(kerberos_state)
+        return c
+
+
+
+    # These functions didn't exist before Python2.7:
+    if sys.version_info < (2, 7):
+        import warnings
+
+        def skipTest(self, reason):
+            raise SkipTest(reason)
+
+        def assertIn(self, member, container, msg=None):
+            self.assertTrue(member in container, msg)
+
+        def assertIs(self, a, b, msg=None):
+            self.assertTrue(a is b, msg)
+
+        def assertIsNot(self, a, b, msg=None):
+            self.assertTrue(a is not b, msg)
+
+        def assertIsNotNone(self, a, msg=None):
+            self.assertTrue(a is not None)
+
+        def assertIsInstance(self, a, b, msg=None):
+            self.assertTrue(isinstance(a, b), msg)
+
+        def assertIsNone(self, a, msg=None):
+            self.assertTrue(a is None, msg)
+
+        def assertGreater(self, a, b, msg=None):
+            self.assertTrue(a > b, msg)
+
+        def assertGreaterEqual(self, a, b, msg=None):
+            self.assertTrue(a >= b, msg)
+
+        def assertLess(self, a, b, msg=None):
+            self.assertTrue(a < b, msg)
+
+        def assertLessEqual(self, a, b, msg=None):
+            self.assertTrue(a <= b, msg)
+
+        def addCleanup(self, fn, *args, **kwargs):
+            self._cleanups = getattr(self, "_cleanups", []) + [
+                (fn, args, kwargs)]
+
+        def _addSkip(self, result, reason):
+            addSkip = getattr(result, 'addSkip', None)
+            if addSkip is not None:
+                addSkip(self, reason)
+            else:
+                warnings.warn("TestResult has no addSkip method, skips not reported",
+                              RuntimeWarning, 2)
+                result.addSuccess(self)
+
+        def run(self, result=None):
+            if result is None: result = self.defaultTestResult()
+            result.startTest(self)
+            testMethod = getattr(self, self._testMethodName)
+            try:
+                try:
+                    self.setUp()
+                except SkipTest as e:
+                    self._addSkip(result, str(e))
+                    return
+                except KeyboardInterrupt:
+                    raise
+                except:
+                    result.addError(self, self._exc_info())
+                    return
+
+                ok = False
+                try:
+                    testMethod()
+                    ok = True
+                except SkipTest as e:
+                    self._addSkip(result, str(e))
+                    return
+                except self.failureException:
+                    result.addFailure(self, self._exc_info())
+                except KeyboardInterrupt:
+                    raise
+                except:
+                    result.addError(self, self._exc_info())
+
+                try:
+                    self.tearDown()
+                except SkipTest as e:
+                    self._addSkip(result, str(e))
+                except KeyboardInterrupt:
+                    raise
+                except:
+                    result.addError(self, self._exc_info())
+                    ok = False
+
+                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
+                    fn(*args, **kwargs)
+                if ok: result.addSuccess(self)
+            finally:
+                result.stopTest(self)
+
+
+class LdbTestCase(TestCase):
     """Trivial test case for running tests against a LDB."""
 
     def setUp(self):
@@ -91,17 +252,20 @@ def env_loadparm():
     try:
         lp.load(os.environ["SMB_CONF_PATH"])
     except KeyError:
-        raise Exception("SMB_CONF_PATH not set")
+        raise KeyError("SMB_CONF_PATH not set")
     return lp
 
 
-def env_get_var_value(var_name):
+def env_get_var_value(var_name, allow_missing=False):
     """Returns value for variable in os.environ
 
     Function throws AssertionError if variable is defined.
     Unit-test based python tests require certain input params
     to be set in environment, otherwise they can't be run
     """
+    if allow_missing:
+        if var_name not in os.environ.keys():
+            return None
     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
     return os.environ[var_name]
 
@@ -142,7 +306,7 @@ class BlackboxProcessError(Exception):
         return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
                                                                              self.stdout, self.stderr)
 
-class BlackboxTestCase(TestCase):
+class BlackboxTestCase(TestCaseInTempDir):
     """Base test case for blackbox tests."""
 
     def _make_cmdline(self, line):
@@ -154,11 +318,20 @@ class BlackboxTestCase(TestCase):
         return line
 
     def check_run(self, line):
+        self.check_exit_code(line, 0)
+
+    def check_exit_code(self, line, expected):
         line = self._make_cmdline(line)
-        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+        p = subprocess.Popen(line,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             shell=True)
         retcode = p.wait()
-        if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
+        if retcode != expected:
+            raise BlackboxProcessError(retcode,
+                                       line,
+                                       p.stdout.read(),
+                                       p.stderr.read())
 
     def check_output(self, line):
         line = self._make_cmdline(line)
@@ -185,7 +358,6 @@ def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
     to make proper URL for ldb.connect() while using default
     parameters for connection based on test environment
     """
-    samdb_url = samdb_url.lower()
     if not "://" in samdb_url:
         if not ldap_only and os.path.isfile(samdb_url):
             samdb_url = "tdb://%s" % samdb_url
@@ -235,8 +407,29 @@ def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
     return (sam_db, res[0])
 
 
-def delete_force(samdb, dn):
+def connect_samdb_env(env_url, env_username, env_password, lp=None):
+    """Connect to SamDB by getting URL and Credentials from environment
+
+    :param env_url: Environment variable name to get lsb url from
+    :param env_username: Username environment variable
+    :param env_password: Password environment variable
+    :return: sam_db_connection
+    """
+    samdb_url = env_get_var_value(env_url)
+    creds = credentials.Credentials()
+    if lp is None:
+        # guess Credentials parameters here. Otherwise workstation
+        # and domain fields are NULL and gencache code segfalts
+        lp = param.LoadParm()
+        creds.guess(lp)
+    creds.set_username(env_get_var_value(env_username))
+    creds.set_password(env_get_var_value(env_password))
+    return connect_samdb(samdb_url, credentials=creds, lp=lp)
+
+
+def delete_force(samdb, dn, **kwargs):
     try:
-        samdb.delete(dn)
-    except ldb.LdbError, (num, _):
-        assert(num == ldb.ERR_NO_SUCH_OBJECT)
+        samdb.delete(dn, **kwargs)
+    except ldb.LdbError as error:
+        (num, errstr) = error.args
+        assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr