python/samba/tests: make sure samba-tool is called with ${PYTHON}
[amitay/samba.git] / python / samba / tests / __init__.py
index 984b1bf0660fc97f0ad4eb3243217e144f55a43c..ca278b5d1c7a0b8a84ebead62f3d4ff1767d466f 100644 (file)
@@ -1,5 +1,6 @@
 # Unix SMB/CIFS implementation.
 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
+# Copyright (C) Stefan Metzmacher 2014,2015
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 """Samba Python tests."""
 
 import os
+import tempfile
 import ldb
 import samba
-import samba.auth
 from samba import param
-from samba.samdb import SamDB
 from samba import credentials
+from samba.credentials import Credentials
+from samba import gensec
+import socket
+import struct
 import subprocess
 import sys
 import tempfile
 import unittest
+import re
+import samba.auth
+import samba.dcerpc.base
+from samba.compat import PY3, text_type
+from samba.compat import string_types
+from random import randint
+from random import SystemRandom
+import string
+try:
+    from samba.samdb import SamDB
+except ImportError:
+    # We are built without samdb support,
+    # imitate it so that connect_samdb() can recover
+    def SamDB(*args, **kwargs):
+        return None
+
+import samba.ndr
+import samba.dcerpc.dcerpc
+import samba.dcerpc.epmapper
 
 try:
     from unittest import SkipTest
@@ -35,7 +58,8 @@ except ImportError:
     class SkipTest(Exception):
         """Test skipped."""
 
-HEXDUMP_FILTER=''.join([(len(repr(chr(x)))==3) and chr(x) or '.' for x in range(256)])
+HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
+
 
 class TestCase(unittest.TestCase):
     """A Samba test case."""
@@ -55,17 +79,64 @@ class TestCase(unittest.TestCase):
     def get_credentials(self):
         return cmdline_credentials
 
-    def hexdump(self, src, length=8):
+    def get_creds_ccache_name(self):
+        creds = self.get_credentials()
+        ccache = creds.get_named_ccache(self.get_loadparm())
+        ccache_name = ccache.get_name()
+
+        return ccache_name
+
+    def hexdump(self, src):
         N = 0
         result = ''
+        is_string = isinstance(src, string_types)
         while src:
-            s, src = src[:length], src[length:]
-            hexa = ' '.join(["%02X" % ord(x) for x in s])
-            s = s.translate(HEXDUMP_FILTER)
-            result += "%04X   %-*s   %s\n" % (N, length*3, hexa, s)
-            N += length
+            ll = src[:8]
+            lr = src[8:16]
+            src = src[16:]
+            if is_string:
+                hl = ' '.join(["%02X" % ord(x) for x in ll])
+                hr = ' '.join(["%02X" % ord(x) for x in lr])
+                ll = ll.translate(HEXDUMP_FILTER)
+                lr = lr.translate(HEXDUMP_FILTER)
+            else:
+                hl = ' '.join(["%02X" % x for x in ll])
+                hr = ' '.join(["%02X" % x for x in lr])
+                ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
+                lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
+            result += "[%04X] %-*s  %-*s  %s %s\n" % (N, 8 * 3, hl, 8 * 3, hr, ll, lr)
+            N += 16
         return result
 
+    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
+
+        if template is None:
+            assert template is not None
+
+        if username is not None:
+            assert userpass is not None
+
+        if username is None:
+            assert userpass is None
+
+            username = template.get_username()
+            userpass = template.get_password()
+
+        if kerberos_state is None:
+            kerberos_state = template.get_kerberos_state()
+
+        # get a copy of the global creds or a the passed in creds
+        c = Credentials()
+        c.set_username(username)
+        c.set_password(userpass)
+        c.set_domain(template.get_domain())
+        c.set_realm(template.get_realm())
+        c.set_workstation(template.get_workstation())
+        c.set_gensec_features(c.get_gensec_features()
+                              | gensec.FEATURE_SEAL)
+        c.set_kerberos_state(kerberos_state)
+        return c
+
     # These functions didn't exist before Python2.7:
     if sys.version_info < (2, 7):
         import warnings
@@ -91,10 +162,30 @@ class TestCase(unittest.TestCase):
         def assertIsNone(self, a, msg=None):
             self.assertTrue(a is None, msg)
 
+        def assertGreater(self, a, b, msg=None):
+            self.assertTrue(a > b, msg)
+
+        def assertGreaterEqual(self, a, b, msg=None):
+            self.assertTrue(a >= b, msg)
+
+        def assertLess(self, a, b, msg=None):
+            self.assertTrue(a < b, msg)
+
+        def assertLessEqual(self, a, b, msg=None):
+            self.assertTrue(a <= b, msg)
+
         def addCleanup(self, fn, *args, **kwargs):
             self._cleanups = getattr(self, "_cleanups", []) + [
                 (fn, args, kwargs)]
 
+        def assertRegexpMatches(self, text, regex, msg=None):
+            # PY3 note: Python 3 will never see this, but we use
+            # text_type for the benefit of linters.
+            if isinstance(regex, (str, text_type)):
+                regex = re.compile(regex)
+            if not regex.search(text):
+                self.fail(msg)
+
         def _addSkip(self, result, reason):
             addSkip = getattr(result, 'addSkip', None)
             if addSkip is not None:
@@ -105,13 +196,14 @@ class TestCase(unittest.TestCase):
                 result.addSuccess(self)
 
         def run(self, result=None):
-            if result is None: result = self.defaultTestResult()
+            if result is None:
+                result = self.defaultTestResult()
             result.startTest(self)
             testMethod = getattr(self, self._testMethodName)
             try:
                 try:
                     self.setUp()
-                except SkipTest, e:
+                except SkipTest as e:
                     self._addSkip(result, str(e))
                     return
                 except KeyboardInterrupt:
@@ -124,7 +216,7 @@ class TestCase(unittest.TestCase):
                 try:
                     testMethod()
                     ok = True
-                except SkipTest, e:
+                except SkipTest as e:
                     self._addSkip(result, str(e))
                     return
                 except self.failureException:
@@ -136,7 +228,7 @@ class TestCase(unittest.TestCase):
 
                 try:
                     self.tearDown()
-                except SkipTest, e:
+                except SkipTest as e:
                     self._addSkip(result, str(e))
                 except KeyboardInterrupt:
                     raise
@@ -146,17 +238,42 @@ class TestCase(unittest.TestCase):
 
                 for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
                     fn(*args, **kwargs)
-                if ok: result.addSuccess(self)
+                if ok:
+                    result.addSuccess(self)
             finally:
                 result.stopTest(self)
 
+    def assertStringsEqual(self, a, b, msg=None, strip=False):
+        """Assert equality between two strings and highlight any differences.
+        If strip is true, leading and trailing whitespace is ignored."""
+        if strip:
+            a = a.strip()
+            b = b.strip()
+
+        if a != b:
+            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
+                             "a diff follows\n"
+                             % ('when stripped ' if strip else '',
+                                len(a), len(b),
+                                ))
+
+            from difflib import unified_diff
+            diff = unified_diff(a.splitlines(True),
+                                b.splitlines(True),
+                                'a', 'b')
+            for line in diff:
+                sys.stderr.write(line)
+
+            self.fail(msg)
+
 
 class LdbTestCase(TestCase):
     """Trivial test case for running tests against a LDB."""
 
     def setUp(self):
         super(LdbTestCase, self).setUp()
-        self.filename = os.tempnam()
+        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
+        self.filename = self.tempfile.name
         self.ldb = samba.Ldb(self.filename)
 
     def set_modules(self, modules=[]):
@@ -190,19 +307,23 @@ def env_loadparm():
     return lp
 
 
-def env_get_var_value(var_name):
+def env_get_var_value(var_name, allow_missing=False):
     """Returns value for variable in os.environ
 
     Function throws AssertionError if variable is defined.
     Unit-test based python tests require certain input params
     to be set in environment, otherwise they can't be run
     """
+    if allow_missing:
+        if var_name not in os.environ.keys():
+            return None
     assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
     return os.environ[var_name]
 
 
 cmdline_credentials = None
 
+
 class RpcInterfaceTestCase(TestCase):
     """DCE/RPC Test case."""
 
@@ -213,7 +334,7 @@ class ValidNetbiosNameTests(TestCase):
         self.assertTrue(samba.valid_netbios_name("FOO"))
 
     def test_too_long(self):
-        self.assertFalse(samba.valid_netbios_name("FOO"*10))
+        self.assertFalse(samba.valid_netbios_name("FOO" * 10))
 
     def test_invalid_characters(self):
         self.assertFalse(samba.valid_netbios_name("*BLA"))
@@ -227,41 +348,73 @@ class BlackboxProcessError(Exception):
     (S.stderr)
     """
 
-    def __init__(self, returncode, cmd, stdout, stderr):
+    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
         self.returncode = returncode
         self.cmd = cmd
         self.stdout = stdout
         self.stderr = stderr
+        self.msg = msg
 
     def __str__(self):
-        return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
-                                                                             self.stdout, self.stderr)
+        s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
+             (self.cmd, self.returncode, self.stdout, self.stderr))
+        if self.msg is not None:
+            s = "%s; message: %s" % (s, self.msg)
 
-class BlackboxTestCase(TestCase):
+        return s
+
+
+class BlackboxTestCase(TestCaseInTempDir):
     """Base test case for blackbox tests."""
 
     def _make_cmdline(self, line):
         bindir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../bin"))
         parts = line.split(" ")
         if os.path.exists(os.path.join(bindir, parts[0])):
+            cmd = parts[0]
             parts[0] = os.path.join(bindir, parts[0])
+            if cmd == "samba-tool" and os.getenv("PYTHON", None):
+                parts = [os.environ["PYTHON"]] + parts
         line = " ".join(parts)
         return line
 
-    def check_run(self, line):
+    def check_run(self, line, msg=None):
+        self.check_exit_code(line, 0, msg=msg)
+
+    def check_exit_code(self, line, expected, msg=None):
         line = self._make_cmdline(line)
-        p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
-        retcode = p.wait()
-        if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
+        p = subprocess.Popen(line,
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE,
+                             shell=True)
+        stdoutdata, stderrdata = p.communicate()
+        retcode = p.returncode
+        if retcode != expected:
+            raise BlackboxProcessError(retcode,
+                                       line,
+                                       stdoutdata,
+                                       stderrdata,
+                                       msg)
 
     def check_output(self, line):
         line = self._make_cmdline(line)
         p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
-        retcode = p.wait()
+        stdoutdata, stderrdata = p.communicate()
+        retcode = p.returncode
         if retcode:
-            raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
-        return p.stdout.read()
+            raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
+        return stdoutdata
+
+    # Generate a random password that can be safely  passed on the command line
+    # i.e. it does not contain any shell meta characters.
+    def random_password(self, count=32):
+        password = SystemRandom().choice(string.ascii_uppercase)
+        password += SystemRandom().choice(string.digits)
+        password += SystemRandom().choice(string.ascii_lowercase)
+        password += ''.join(SystemRandom().choice(string.ascii_uppercase +
+                    string.ascii_lowercase +
+                    string.digits) for x in range(count - 3))
+        return password
 
 
 def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
@@ -280,7 +433,7 @@ def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
     to make proper URL for ldb.connect() while using default
     parameters for connection based on test environment
     """
-    if not "://" in samdb_url:
+    if "://" not in samdb_url:
         if not ldap_only and os.path.isfile(samdb_url):
             samdb_url = "tdb://%s" % samdb_url
         else:
@@ -349,8 +502,22 @@ def connect_samdb_env(env_url, env_username, env_password, lp=None):
     return connect_samdb(samdb_url, credentials=creds, lp=lp)
 
 
-def delete_force(samdb, dn):
+def delete_force(samdb, dn, **kwargs):
     try:
-        samdb.delete(dn)
-    except ldb.LdbError, (num, errstr):
+        samdb.delete(dn, **kwargs)
+    except ldb.LdbError as error:
+        (num, errstr) = error.args
         assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
+
+
+def create_test_ou(samdb, name):
+    """Creates a unique OU for the test"""
+
+    # Add some randomness to the test OU. Replication between the testenvs is
+    # constantly happening in the background. Deletion of the last test's
+    # objects can be slow to replicate out. So the OU created by a previous
+    # testenv may still exist at the point that tests start on another testenv.
+    rand = randint(1, 10000000)
+    dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
+    samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
+    return dn