# Unix SMB/CIFS implementation.
# Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
+# Copyright (C) Stefan Metzmacher 2014,2015
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
"""Samba Python tests."""
import os
+import tempfile
import ldb
import samba
-import samba.auth
from samba import param
-from samba.samdb import SamDB
from samba import credentials
+from samba.credentials import Credentials
+from samba import gensec
+import socket
+import struct
import subprocess
+import sys
import tempfile
import unittest
+import re
+import samba.auth
+import samba.dcerpc.base
+from samba.compat import PY3, text_type
+from samba.compat import string_types
+from random import randint
+try:
+ from samba.samdb import SamDB
+except ImportError:
+ # We are built without samdb support,
+ # imitate it so that connect_samdb() can recover
+ def SamDB(*args, **kwargs):
+ return None
+
+import samba.ndr
+import samba.dcerpc.dcerpc
+import samba.dcerpc.epmapper
try:
from unittest import SkipTest
class SkipTest(Exception):
"""Test skipped."""
+HEXDUMP_FILTER = bytearray([x if ((len(repr(chr(x))) == 3) and (x < 127)) else ord('.') for x in range(256)])
class TestCase(unittest.TestCase):
"""A Samba test case."""
def get_credentials(self):
return cmdline_credentials
+ def get_creds_ccache_name(self):
+ creds = self.get_credentials()
+ ccache = creds.get_named_ccache(self.get_loadparm())
+ ccache_name = ccache.get_name()
+
+ return ccache_name
+
+ def hexdump(self, src):
+ N = 0
+ result = ''
+ is_string = isinstance(src, string_types)
+ while src:
+ ll = src[:8]
+ lr = src[8:16]
+ src = src[16:]
+ if is_string:
+ hl = ' '.join(["%02X" % ord(x) for x in ll])
+ hr = ' '.join(["%02X" % ord(x) for x in lr])
+ ll = ll.translate(HEXDUMP_FILTER)
+ lr = lr.translate(HEXDUMP_FILTER)
+ else:
+ hl = ' '.join(["%02X" % x for x in ll])
+ hr = ' '.join(["%02X" % x for x in lr])
+ ll = ll.translate(HEXDUMP_FILTER).decode('utf8')
+ lr = lr.translate(HEXDUMP_FILTER).decode('utf8')
+ result += "[%04X] %-*s %-*s %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
+ N += 16
+ return result
+
+ def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
+
+ if template is None:
+ assert template is not None
+
+ if username is not None:
+ assert userpass is not None
+
+ if username is None:
+ assert userpass is None
+
+ username = template.get_username()
+ userpass = template.get_password()
+
+ if kerberos_state is None:
+ kerberos_state = template.get_kerberos_state()
+
+ # get a copy of the global creds or a the passed in creds
+ c = Credentials()
+ c.set_username(username)
+ c.set_password(userpass)
+ c.set_domain(template.get_domain())
+ c.set_realm(template.get_realm())
+ c.set_workstation(template.get_workstation())
+ c.set_gensec_features(c.get_gensec_features()
+ | gensec.FEATURE_SEAL)
+ c.set_kerberos_state(kerberos_state)
+ return c
+
+
+
# These functions didn't exist before Python2.7:
- if not getattr(unittest.TestCase, "skipTest", None):
+ if sys.version_info < (2, 7):
+ import warnings
+
def skipTest(self, reason):
raise SkipTest(reason)
- if not getattr(unittest.TestCase, "assertIs", None):
- def assertIs(self, a, b):
- self.assertTrue(a is b)
+ def assertIn(self, member, container, msg=None):
+ self.assertTrue(member in container, msg)
+
+ def assertIs(self, a, b, msg=None):
+ self.assertTrue(a is b, msg)
- if not getattr(unittest.TestCase, "assertIsNot", None):
- def assertIsNot(self, a, b):
- self.assertTrue(a is not b)
+ def assertIsNot(self, a, b, msg=None):
+ self.assertTrue(a is not b, msg)
- if not getattr(unittest.TestCase, "assertIsInstance", None):
- def assertIsInstance(self, a, b):
- self.assertTrue(isinstance(a, b))
+ def assertIsNotNone(self, a, msg=None):
+ self.assertTrue(a is not None)
+
+ def assertIsInstance(self, a, b, msg=None):
+ self.assertTrue(isinstance(a, b), msg)
+
+ def assertIsNone(self, a, msg=None):
+ self.assertTrue(a is None, msg)
+
+ def assertGreater(self, a, b, msg=None):
+ self.assertTrue(a > b, msg)
+
+ def assertGreaterEqual(self, a, b, msg=None):
+ self.assertTrue(a >= b, msg)
+
+ def assertLess(self, a, b, msg=None):
+ self.assertTrue(a < b, msg)
+
+ def assertLessEqual(self, a, b, msg=None):
+ self.assertTrue(a <= b, msg)
- if not getattr(unittest.TestCase, "addCleanup", None):
def addCleanup(self, fn, *args, **kwargs):
self._cleanups = getattr(self, "_cleanups", []) + [
(fn, args, kwargs)]
- def tearDown(self):
- super(TestCase, self).tearDown()
- for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
- fn(*args, **kwargs)
+ def assertRegexpMatches(self, text, regex, msg=None):
+ # PY3 note: Python 3 will never see this, but we use
+ # text_type for the benefit of linters.
+ if isinstance(regex, (str, text_type)):
+ regex = re.compile(regex)
+ if not regex.search(text):
+ self.fail(msg)
+
+ def _addSkip(self, result, reason):
+ addSkip = getattr(result, 'addSkip', None)
+ if addSkip is not None:
+ addSkip(self, reason)
+ else:
+ warnings.warn("TestResult has no addSkip method, skips not reported",
+ RuntimeWarning, 2)
+ result.addSuccess(self)
+
+ def run(self, result=None):
+ if result is None: result = self.defaultTestResult()
+ result.startTest(self)
+ testMethod = getattr(self, self._testMethodName)
+ try:
+ try:
+ self.setUp()
+ except SkipTest as e:
+ self._addSkip(result, str(e))
+ return
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self._exc_info())
+ return
+
+ ok = False
+ try:
+ testMethod()
+ ok = True
+ except SkipTest as e:
+ self._addSkip(result, str(e))
+ return
+ except self.failureException:
+ result.addFailure(self, self._exc_info())
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self._exc_info())
+
+ try:
+ self.tearDown()
+ except SkipTest as e:
+ self._addSkip(result, str(e))
+ except KeyboardInterrupt:
+ raise
+ except:
+ result.addError(self, self._exc_info())
+ ok = False
+
+ for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
+ fn(*args, **kwargs)
+ if ok: result.addSuccess(self)
+ finally:
+ result.stopTest(self)
+
+ def assertStringsEqual(self, a, b, msg=None, strip=False):
+ """Assert equality between two strings and highlight any differences.
+ If strip is true, leading and trailing whitespace is ignored."""
+ if strip:
+ a = a.strip()
+ b = b.strip()
+
+ if a != b:
+ sys.stderr.write("The strings differ %s(lengths %d vs %d); "
+ "a diff follows\n"
+ % ('when stripped ' if strip else '',
+ len(a), len(b),
+ ))
+
+ from difflib import unified_diff
+ diff = unified_diff(a.splitlines(True),
+ b.splitlines(True),
+ 'a', 'b')
+ for line in diff:
+ sys.stderr.write(line)
+
+ self.fail(msg)
class LdbTestCase(TestCase):
def setUp(self):
super(LdbTestCase, self).setUp()
- self.filename = os.tempnam()
+ self.tempfile = tempfile.NamedTemporaryFile(delete=False)
+ self.filename = self.tempfile.name
self.ldb = samba.Ldb(self.filename)
def set_modules(self, modules=[]):
return lp
-def env_get_var_value(var_name):
+def env_get_var_value(var_name, allow_missing=False):
"""Returns value for variable in os.environ
Function throws AssertionError if variable is defined.
Unit-test based python tests require certain input params
to be set in environment, otherwise they can't be run
"""
+ if allow_missing:
+ if var_name not in os.environ.keys():
+ return None
assert var_name in os.environ.keys(), "Please supply %s in environment" % var_name
return os.environ[var_name]
(S.stderr)
"""
- def __init__(self, returncode, cmd, stdout, stderr):
+ def __init__(self, returncode, cmd, stdout, stderr, msg=None):
self.returncode = returncode
self.cmd = cmd
self.stdout = stdout
self.stderr = stderr
+ self.msg = msg
def __str__(self):
- return "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % (self.cmd, self.returncode,
- self.stdout, self.stderr)
+ s = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
+ (self.cmd, self.returncode, self.stdout, self.stderr))
+ if self.msg is not None:
+ s = "%s; message: %s" % (s, self.msg)
+
+ return s
-class BlackboxTestCase(TestCase):
+class BlackboxTestCase(TestCaseInTempDir):
"""Base test case for blackbox tests."""
def _make_cmdline(self, line):
line = " ".join(parts)
return line
- def check_run(self, line):
+ def check_run(self, line, msg=None):
+ self.check_exit_code(line, 0, msg=msg)
+
+ def check_exit_code(self, line, expected, msg=None):
line = self._make_cmdline(line)
- p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
- retcode = p.wait()
- if retcode:
- raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
+ p = subprocess.Popen(line,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True)
+ stdoutdata, stderrdata = p.communicate()
+ retcode = p.returncode
+ if retcode != expected:
+ raise BlackboxProcessError(retcode,
+ line,
+ stdoutdata,
+ stderrdata,
+ msg)
def check_output(self, line):
line = self._make_cmdline(line)
p = subprocess.Popen(line, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, close_fds=True)
- retcode = p.wait()
+ stdoutdata, stderrdata = p.communicate()
+ retcode = p.returncode
if retcode:
- raise BlackboxProcessError(retcode, line, p.stdout.read(), p.stderr.read())
- return p.stdout.read()
+ raise BlackboxProcessError(retcode, line, stdoutdata, stderrdata)
+ return stdoutdata
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
return connect_samdb(samdb_url, credentials=creds, lp=lp)
-def delete_force(samdb, dn):
+def delete_force(samdb, dn, **kwargs):
try:
- samdb.delete(dn)
- except ldb.LdbError, (num, errstr):
+ samdb.delete(dn, **kwargs)
+ except ldb.LdbError as error:
+ (num, errstr) = error.args
assert num == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % errstr
+
+def create_test_ou(samdb, name):
+ """Creates a unique OU for the test"""
+
+ # Add some randomness to the test OU. Replication between the testenvs is
+ # constantly happening in the background. Deletion of the last test's
+ # objects can be slow to replicate out. So the OU created by a previous
+ # testenv may still exist at the point that tests start on another testenv.
+ rand = randint(1, 10000000)
+ dn = ldb.Dn(samdb, "OU=%s%d,%s" % (name, rand, samdb.get_default_basedn()))
+ samdb.add({"dn": dn, "objectclass": "organizationalUnit"})
+ return dn