PEP8: fix E225: missing whitespace around operator
[nivanova/samba-autobuild/.git] / python / samba / tests / __init__.py
index b73308ecd49629ea963d634c0007d8616e4e3cde..60c18c7323465704fa8bd7b9349549d5014e9ebe 100644 (file)
@@ -1,5 +1,6 @@
 # Unix SMB/CIFS implementation.
 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
+# Copyright (C) Stefan Metzmacher 2014,2015
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 """Samba Python tests."""
 
 import os
+import tempfile
 import ldb
 import samba
-import samba.auth
 from samba import param
-from samba.samdb import SamDB
 from samba import credentials
+from samba.credentials import Credentials
+from samba import gensec
+import socket
+import struct
 import subprocess
+import sys
 import tempfile
 import unittest
+import re
+import samba.auth
+import samba.dcerpc.base
+from samba.compat import PY3, text_type
+from samba.compat import string_types
+from random import randint
+try:
+    from samba.samdb import SamDB
+except ImportError:
+    # We are built without samdb support,
+    # imitate it so that connect_samdb() can recover
+    def SamDB(*args, **kwargs):
+        return None
+
+import samba.ndr
+import samba.dcerpc.dcerpc
+import samba.dcerpc.epmapper
 
 try:
     from unittest import SkipTest
@@ -34,6 +56,7 @@ except ImportError:
     class SkipTest(Exception):
         """Test skipped."""
 
+HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
 
 class TestCase(unittest.TestCase):
     """A Samba test case."""
@@ -53,32 +76,192 @@ class TestCase(unittest.TestCase):
     def get_credentials(self):
         return cmdline_credentials
 
+    def get_creds_ccache_name(self):
+        creds = self.get_credentials()
+        ccache = creds.get_named_ccache(self.get_loadparm())
+        ccache_name = ccache.get_name()
+
+        return ccache_name
+
+    def hexdump(self, src):
+        N = 0
+        result = ''
+        is_string = isinstance(src, string_types)
+        while src:
+            ll = src[:8]
+            lr = src[8:16]
+            src = src[16:]
+            if is_string:
+                hl = ' '.join(["%02X" % ord(x) for x in ll])
+                hr = ' '.join(["%02X" % ord(x) for x in lr])
+                ll = ll.translate(HEXDUMP_FILTER)
+                lr = lr.translate(HEXDUMP_FILTER)
+            else:
+                hl = ' '.join(["%02X" % x for x in ll])
+                hr = ' '.join(["%02X" % x for x in lr])
+                ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
+                lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
+            result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
+            N += 16
+        return result
+
+    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
+
+        if template is None:
+            assert template is not None
+
+        if username is not None:
+            assert userpass is not None
+
+        if username is None:
+            assert userpass is None
+
+            username = template.get_username()
+            userpass = template.get_password()
+
+        if kerberos_state is None:
+            kerberos_state = template.get_kerberos_state()
+
+        # get a copy of the global creds or a the passed in creds
+        c = Credentials()
+        c.set_username(username)
+        c.set_password(userpass)
+        c.set_domain(template.get_domain())
+        c.set_realm(template.get_realm())
+        c.set_workstation(template.get_workstation())
+        c.set_gensec_features(c.get_gensec_features()
+                              | gensec.FEATURE_SEAL)
+        c.set_kerberos_state(kerberos_state)
+        return c
+
+
+
     # These functions didn't exist before Python2.7:
-    if not getattr(unittest.TestCase, "skipTest", None):
+    if sys.version_info < (2, 7):
+        import warnings
+
         def skipTest(self, reason):
             raise SkipTest(reason)
 
-    if not getattr(unittest.TestCase, "assertIs", None):
-        def assertIs(self, a, b):
-            self.assertTrue(a is b)
+        def assertIn(self, member, container, msg=None):
+            self.assertTrue(member in container, msg)
+
+        def assertIs(self, a, b, msg=None):
+            self.assertTrue(a is b, msg)
 
-    if not getattr(unittest.TestCase, "assertIsNot", None):
-        def assertIsNot(self, a, b):
-            self.assertTrue(a is not b)
+        def assertIsNot(self, a, b, msg=None):
+            self.assertTrue(a is not b, msg)
 
-    if not getattr(unittest.TestCase, "assertIsInstance", None):
-        def assertIsInstance(self, a, b):
-            self.assertTrue(isinstance(a, b))
+        def assertIsNotNone(self, a, msg=None):
+            self.assertTrue(a is not None)
+
+        def assertIsInstance(self, a, b, msg=None):
+            self.assertTrue(isinstance(a, b), msg)
+
+        def assertIsNone(self, a, msg=None):
+            self.assertTrue(a is None, msg)
+
+        def assertGreater(self, a, b, msg=None):
+            self.assertTrue(a > b, msg)
+
+        def assertGreaterEqual(self, a, b, msg=None):
+            self.assertTrue(a >= b, msg)
+
+        def assertLess(self, a, b, msg=None):
+            self.assertTrue(a < b, msg)
+
+        def assertLessEqual(self, a, b, msg=None):
+            self.assertTrue(a <= b, msg)
 
-    if not getattr(unittest.TestCase, "addCleanup", None):
         def addCleanup(self, fn, *args, **kwargs):
             self._cleanups = getattr(self, "_cleanups", []) + [
                 (fn, args, kwargs)]
 
-        def tearDown(self):
-            super(TestCase, self).tearDown()
-            for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
-                fn(*args, **kwargs)
+        def assertRegexpMatches(self, text, regex, msg=None):
+            # PY3 note: Python 3 will never see this, but we use
+            # text_type for the benefit of linters.
+            if isinstance(regex, (str, text_type)):
+                regex = re.compile(regex)
+            if not regex.search(text):
+                self.fail(msg)
+
+        def _addSkip(self, result, reason):
+            addSkip = getattr(result, 'addSkip', None)
+            if addSkip is not None:
+                addSkip(self, reason)
+            else:
+                warnings.warn("TestResult has no addSkip method, skips not reported",
+                              RuntimeWarning, 2)
+                result.addSuccess(self)
+
+        def run(self, result=None):
+            if result is None: result = self.defaultTestResult()
+            result.startTest(self)
+            testMethod = getattr(self, self._testMethodName)
+            try:
+                try:
+                    self.setUp()
+                except SkipTest as e:
+                    self._addSkip(result, str(e))
+                    return
+                except KeyboardInterrupt:
+                    raise
+                except:
+                    result.addError(self, self._exc_info())
+                    return
+
+                ok = False
+                try:
+                    testMethod()
+                    ok = True
+                except SkipTest as e:
+                    self._addSkip(result, str(e))
+                    return
+                except self.failureException:
+                    result.addFailure(self, self._exc_info())
+                except KeyboardInterrupt:
+                    raise
+                except:
+                    result.addError(self, self._exc_info())
+
+                try:
+                    self.tearDown()
+                except SkipTest as e:
+                    self._addSkip(result, str(e))
+                except KeyboardInterrupt:
+                    raise
+                except:
+                    result.addError(self, self._exc_info())
+                    ok = False
+
+                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
+                    fn(*args, **kwargs)
+                if ok: result.addSuccess(self)
+            finally:
+                result.stopTest(self)
+
+    def assertStringsEqual(self, a, b, msg=None, strip=False):
+        """Assert equality between two strings and highlight any differences.
+        If strip is true, leading and trailing whitespace is ignored."""
+        if strip:
+            a = a.strip()
+            b = b.strip()
+
+        if a != b:
+            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
+                             "a diff follows\n"
+                             % ('when stripped ' if strip else '',
+                                len(a), len(b),
+                                ))
+
+            from difflib import unified_diff
+            diff = unified_diff(a.splitlines(True),
+                                b.splitlines(True),
+                                'a', 'b')
+            for line in diff:
+                sys.stderr.write(line)
+
+            self.fail(msg)
 
 
 class LdbTestCase(TestCase):
@@ -86,7 +269,8 @@ class LdbTestCase(TestCase):
 
     def setUp(self):
         super(LdbTestCase, self).setUp()
-        self.filename = os.tempnam()
+        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
+        self.filename = self.tempfile.name
         self.ldb = samba.Ldb(self.filename)
 
     def set_modules(self, modules=[]):
@@ -120,13 +304,16 @@ def env_loadparm():
     return lp
 
 
-def env_get_var_value(var_name):
+def env_get_var_value(var_name, allow_missing=False):
     """Returns value for variable in os.environ
 
     Function throws AssertionError if variable is defined.
     Unit-test based python tests require certain input params
     to be set in environment, otherwise they can't be run
     """
+    if allow_missing:
+        if var_name not in os.environ.keys():
+            return None
     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
     return os.environ[var_name]
 
@@ -157,17 +344,22 @@ class BlackboxProcessError(Exception):
     (S.stderr)
     """
 
-    def __init__(self, returncode, cmd, stdout, stderr):
+    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
         self.returncode = returncode
         self.cmd = cmd
         self.stdout = stdout
         self.stderr = stderr
+        self.msg = msg
 
     def __str__(self):
-        return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
-                                                                             self.stdout, self.stderr)
+        s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
+             (self.cmd, self.returncode, self.stdout, self.stderr))
+        if self.msg is not None:
+            s = "%s; message: %s" % (s, self.msg)
+
+        return s
 
-class BlackboxTestCase(TestCase):
+class BlackboxTestCase(TestCaseInTempDir):
     """Base test case for blackbox tests."""
 
     def _make_cmdline(self, line):
@@ -178,20 +370,32 @@ class BlackboxTestCase(TestCase):
         line = " ".join(parts)
         return line
 
-    def check_run(self, line):
+    def check_run(self, line, msg=None):
+        self.check_exit_code(line, 0, msg=msg)
+
+    def check_exit_code(self, line, expected, msg=None):
         line = self._make_cmdline(line)
-        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
-        retcode = p.wait()
-        if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
+        p = subprocess.Popen(line,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             shell=True)
+        stdoutdata, stderrdata = p.communicate()
+        retcode = p.returncode
+        if retcode != expected:
+            raise BlackboxProcessError(retcode,
+                                       line,
+                                       stdoutdata,
+                                       stderrdata,
+                                       msg)
 
     def check_output(self, line):
         line = self._make_cmdline(line)
         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
-        retcode = p.wait()
+        stdoutdata, stderrdata = p.communicate()
+        retcode = p.returncode
         if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
-        return p.stdout.read()
+            raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
+        return stdoutdata
 
 
 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
@@ -279,8 +483,21 @@ def connect_samdb_env(env_url, env_username, env_password, lp=None):
     return connect_samdb(samdb_url, credentials=creds, lp=lp)
 
 
-def delete_force(samdb, dn):
+def delete_force(samdb, dn, **kwargs):
     try:
-        samdb.delete(dn)
-    except ldb.LdbError, (num, errstr):
+        samdb.delete(dn, **kwargs)
+    except ldb.LdbError as error:
+        (num, errstr) = error.args
         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
+
+def create_test_ou(samdb, name):
+    """Creates a unique OU for the test"""
+
+    # Add some randomness to the test OU. Replication between the testenvs is
+    # constantly happening in the background. Deletion of the last test's
+    # objects can be slow to replicate out. So the OU created by a previous
+    # testenv may still exist at the point that tests start on another testenv.
+    rand = randint(1, 10000000)
+    dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
+    samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
+    return dn