python/samba/tests: fix SamDB dummy replacement
[nivanova/samba-autobuild/.git] / python / samba / tests / __init__.py
index 21a55c24735d03231c09622e61887b70641f65b9..97ce8d1c495be8dad53bd2e9cc2a99a5154f05ec 100644 (file)
 """Samba Python tests."""
 
 import os
+import tempfile
 import ldb
 import samba
 from samba import param
 from samba import credentials
 from samba.credentials import Credentials
+from samba import gensec
 import socket
 import struct
 import subprocess
 import sys
 import tempfile
 import unittest
+import re
 import samba.auth
 import samba.dcerpc.base
-from samba.compat import PY3
-if not PY3:
-    # Py2 only
+from samba.compat import PY3, text_type
+from samba.compat import string_types
+from random import randint
+try:
     from samba.samdb import SamDB
-    import samba.ndr
-    import samba.dcerpc.dcerpc
-    import samba.dcerpc.epmapper
-    from samba import gensec
+except ImportError:
+    # We are built without samdb support,
+    # imitate it so that connect_samdb() can recover
+    def SamDB(*args, **kwargs):
+        return None
+
+import samba.ndr
+import samba.dcerpc.dcerpc
+import samba.dcerpc.epmapper
 
 try:
     from unittest import SkipTest
@@ -47,7 +56,7 @@ except ImportError:
     class SkipTest(Exception):
         """Test skipped."""
 
-HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
+HEXDUMP_FILTER=bytearray([x if ((len(repr(chr(x)))==3) and (x < 127)) else ord('.') for x in range(256)])
 
 class TestCase(unittest.TestCase):
     """A Samba test case."""
@@ -67,21 +76,66 @@ class TestCase(unittest.TestCase):
     def get_credentials(self):
         return cmdline_credentials
 
+    def get_creds_ccache_name(self):
+        creds = self.get_credentials()
+        ccache = creds.get_named_ccache(self.get_loadparm())
+        ccache_name = ccache.get_name()
+
+        return ccache_name
+
     def hexdump(self, src):
         N = 0
         result = ''
+        is_string = isinstance(src, string_types)
         while src:
             ll = src[:8]
             lr = src[8:16]
             src = src[16:]
-            hl = ' '.join(["%02X" % ord(x) for x in ll])
-            hr = ' '.join(["%02X" % ord(x) for x in lr])
-            ll = ll.translate(HEXDUMP_FILTER)
-            lr = lr.translate(HEXDUMP_FILTER)
+            if is_string:
+                hl = ' '.join(["%02X" % ord(x) for x in ll])
+                hr = ' '.join(["%02X" % ord(x) for x in lr])
+                ll = ll.translate(HEXDUMP_FILTER)
+                lr = lr.translate(HEXDUMP_FILTER)
+            else:
+                hl = ' '.join(["%02X" % x for x in ll])
+                hr = ' '.join(["%02X" % x for x in lr])
+                ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
+                lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
             result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
             N += 16
         return result
 
+    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
+
+        if template is None:
+            assert template is not None
+
+        if username is not None:
+            assert userpass is not None
+
+        if username is None:
+            assert userpass is None
+
+            username = template.get_username()
+            userpass = template.get_password()
+
+        if kerberos_state is None:
+            kerberos_state = template.get_kerberos_state()
+
+        # get a copy of the global creds or a the passed in creds
+        c = Credentials()
+        c.set_username(username)
+        c.set_password(userpass)
+        c.set_domain(template.get_domain())
+        c.set_realm(template.get_realm())
+        c.set_workstation(template.get_workstation())
+        c.set_gensec_features(c.get_gensec_features()
+                              | gensec.FEATURE_SEAL)
+        c.set_kerberos_state(kerberos_state)
+        return c
+
+
+
     # These functions didn't exist before Python2.7:
     if sys.version_info < (2, 7):
         import warnings
@@ -123,6 +177,14 @@ class TestCase(unittest.TestCase):
             self._cleanups = getattr(self, "_cleanups", []) + [
                 (fn, args, kwargs)]
 
+        def assertRegexpMatches(self, text, regex, msg=None):
+            # PY3 note: Python 3 will never see this, but we use
+            # text_type for the benefit of linters.
+            if isinstance(regex, (str, text_type)):
+                regex = re.compile(regex)
+            if not regex.search(text):
+                self.fail(msg)
+
         def _addSkip(self, result, reason):
             addSkip = getattr(result, 'addSkip', None)
             if addSkip is not None:
@@ -178,13 +240,37 @@ class TestCase(unittest.TestCase):
             finally:
                 result.stopTest(self)
 
+    def assertStringsEqual(self, a, b, msg=None, strip=False):
+        """Assert equality between two strings and highlight any differences.
+        If strip is true, leading and trailing whitespace is ignored."""
+        if strip:
+            a = a.strip()
+            b = b.strip()
+
+        if a != b:
+            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
+                             "a diff follows\n"
+                             % ('when stripped ' if strip else '',
+                                len(a), len(b),
+                             ))
+
+            from difflib import unified_diff
+            diff = unified_diff(a.splitlines(True),
+                                b.splitlines(True),
+                                'a', 'b')
+            for line in diff:
+                sys.stderr.write(line)
+
+            self.fail(msg)
+
 
 class LdbTestCase(TestCase):
     """Trivial test case for running tests against a LDB."""
 
     def setUp(self):
         super(LdbTestCase, self).setUp()
-        self.filename = os.tempnam()
+        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
+        self.filename = self.tempfile.name
         self.ldb = samba.Ldb(self.filename)
 
     def set_modules(self, modules=[]):
@@ -258,15 +344,20 @@ class BlackboxProcessError(Exception):
     (S.stderr)
     """
 
-    def __init__(self, returncode, cmd, stdout, stderr):
+    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
         self.returncode = returncode
         self.cmd = cmd
         self.stdout = stdout
         self.stderr = stderr
+        self.msg = msg
 
     def __str__(self):
-        return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
-                                                                             self.stdout, self.stderr)
+        s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
+             (self.cmd, self.returncode, self.stdout, self.stderr))
+        if self.msg is not None:
+            s = "%s; message: %s" % (s, self.msg)
+
+        return s
 
 class BlackboxTestCase(TestCaseInTempDir):
     """Base test case for blackbox tests."""
@@ -279,20 +370,32 @@ class BlackboxTestCase(TestCaseInTempDir):
         line = " ".join(parts)
         return line
 
-    def check_run(self, line):
+    def check_run(self, line, msg=None):
+        self.check_exit_code(line, 0, msg=msg)
+
+    def check_exit_code(self, line, expected, msg=None):
         line = self._make_cmdline(line)
-        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
-        retcode = p.wait()
-        if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
+        p = subprocess.Popen(line,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             shell=True)
+        stdoutdata, stderrdata = p.communicate()
+        retcode = p.returncode
+        if retcode != expected:
+            raise BlackboxProcessError(retcode,
+                                       line,
+                                       stdoutdata,
+                                       stderrdata,
+                                       msg)
 
     def check_output(self, line):
         line = self._make_cmdline(line)
         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
-        retcode = p.wait()
+        stdoutdata, stderrdata = p.communicate()
+        retcode = p.returncode
         if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
-        return p.stdout.read()
+            raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
+        return stdoutdata
 
 
 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
@@ -380,9 +483,21 @@ def connect_samdb_env(env_url, env_username, env_password, lp=None):
     return connect_samdb(samdb_url, credentials=creds, lp=lp)
 
 
-def delete_force(samdb, dn):
+def delete_force(samdb, dn, **kwargs):
     try:
-        samdb.delete(dn)
+        samdb.delete(dn, **kwargs)
     except ldb.LdbError as error:
         (num, errstr) = error.args
         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
+
+def create_test_ou(samdb, name):
+    """Creates a unique OU for the test"""
+
+    # Add some randomness to the test OU. Replication between the testenvs is
+    # constantly happening in the background. Deletion of the last test's
+    # objects can be slow to replicate out. So the OU created by a previous
+    # testenv may still exist at the point that tests start on another testenv.
+    rand = randint(1, 10000000)
+    dn = ldb.Dn(samdb, "OU=%s%d,%s" %(name, rand, samdb.get_default_basedn()))
+    samdb.add({ "dn": dn, "objectclass": "organizationalUnit"})
+    return dn