testtools: Update to new upstream revision.
[nivanova/samba-autobuild/.git] / lib / testtools / testtools / matchers.py
index 4725265f9854eb181a724aa3fc1253326a22b0f6..3279306650e57fda9471de8c5bc727998fa6b6e8 100644 (file)
@@ -16,10 +16,14 @@ __all__ = [
     'AllMatch',
     'Annotate',
     'Contains',
+    'DirExists',
     'DocTestMatches',
     'EndsWith',
     'Equals',
+    'FileContains',
+    'FileExists',
     'GreaterThan',
+    'HasPermissions',
     'Is',
     'IsInstance',
     'KeysEqual',
@@ -28,21 +32,27 @@ __all__ = [
     'MatchesAny',
     'MatchesException',
     'MatchesListwise',
+    'MatchesPredicate',
     'MatchesRegex',
     'MatchesSetwise',
     'MatchesStructure',
     'NotEquals',
     'Not',
+    'PathExists',
     'Raises',
     'raises',
+    'SamePath',
     'StartsWith',
+    'TarballContains',
     ]
 
 import doctest
 import operator
 from pprint import pformat
 import re
+import os
 import sys
+import tarfile
 import types
 
 from testtools.compat import (
@@ -205,25 +215,25 @@ class _NonManglingOutputChecker(doctest.OutputChecker):
     """Doctest checker that works with unicode rather than mangling strings
 
     This is needed because current Python versions have tried to fix string
-    encoding related problems, but regressed the default behaviour with unicode
-    inputs in the process.
+    encoding related problems, but regressed the default behaviour with
+    unicode inputs in the process.
 
-    In Python 2.6 and 2.7 `OutputChecker.output_difference` is was changed to
-    return a bytestring encoded as per `sys.stdout.encoding`, or utf-8 if that
-    can't be determined. Worse, that encoding process happens in the innocent
-    looking `_indent` global function. Because the `DocTestMismatch.describe`
-    result may well not be destined for printing to stdout, this is no good
-    for us. To get a unicode return as before, the method is monkey patched if
-    `doctest._encoding` exists.
+    In Python 2.6 and 2.7 ``OutputChecker.output_difference`` is was changed
+    to return a bytestring encoded as per ``sys.stdout.encoding``, or utf-8 if
+    that can't be determined. Worse, that encoding process happens in the
+    innocent looking `_indent` global function. Because the
+    `DocTestMismatch.describe` result may well not be destined for printing to
+    stdout, this is no good for us. To get a unicode return as before, the
+    method is monkey patched if ``doctest._encoding`` exists.
 
     Python 3 has a different problem. For some reason both inputs are encoded
     to ascii with 'backslashreplace', making an escaped string matches its
-    unescaped form. Overriding the offending `OutputChecker._toAscii` method
+    unescaped form. Overriding the offending ``OutputChecker._toAscii`` method
     is sufficient to revert this.
     """
 
     def _toAscii(self, s):
-        """Return `s` unchanged rather than mangling it to ascii"""
+        """Return ``s`` unchanged rather than mangling it to ascii"""
         return s
 
     # Only do this overriding hackery if doctest has a broken _input function
@@ -232,7 +242,7 @@ class _NonManglingOutputChecker(doctest.OutputChecker):
         __f = doctest.OutputChecker.output_difference.im_func
         __g = dict(__f.func_globals)
         def _indent(s, indent=4, _pattern=re.compile("^(?!$)", re.MULTILINE)):
-            """Prepend non-empty lines in `s` with `indent` number of spaces"""
+            """Prepend non-empty lines in ``s`` with ``indent`` number of spaces"""
             return _pattern.sub(indent*" ", s)
         __g["_indent"] = _indent
         output_difference = __F(__f.func_code, __g, "output_difference")
@@ -385,6 +395,39 @@ class _BinaryMismatch(Mismatch):
             return "%s %s %s" % (left, self._mismatch_string, right)
 
 
+class MatchesPredicate(Matcher):
+    """Match if a given function returns True.
+
+    It is reasonably common to want to make a very simple matcher based on a
+    function that you already have that returns True or False given a single
+    argument (i.e. a predicate function).  This matcher makes it very easy to
+    do so. e.g.::
+
+      IsEven = MatchesPredicate(lambda x: x % 2 == 0, '%s is not even')
+      self.assertThat(4, IsEven)
+    """
+
+    def __init__(self, predicate, message):
+        """Create a ``MatchesPredicate`` matcher.
+
+        :param predicate: A function that takes a single argument and returns
+            a value that will be interpreted as a boolean.
+        :param message: A message to describe a mismatch.  It will be formatted
+            with '%' and be given whatever was passed to ``match()``. Thus, it
+            needs to contain exactly one thing like '%s', '%d' or '%f'.
+        """
+        self.predicate = predicate
+        self.message = message
+
+    def __str__(self):
+        return '%s(%r, %r)' % (
+            self.__class__.__name__, self.predicate, self.message)
+
+    def match(self, x):
+        if not self.predicate(x):
+            return Mismatch(self.message % x)
+
+
 class Equals(_BinaryComparison):
     """Matches if the items are equal."""
 
@@ -483,8 +526,16 @@ class MatchesAny(object):
 class MatchesAll(object):
     """Matches if all of the matchers it is created with match."""
 
-    def __init__(self, *matchers):
+    def __init__(self, *matchers, **options):
+        """Construct a MatchesAll matcher.
+
+        Just list the component matchers as arguments in the ``*args``
+        style. If you want only the first mismatch to be reported, past in
+        first_only=True as a keyword argument. By default, all mismatches are
+        reported.
+        """
         self.matchers = matchers
+        self.first_only = options.get('first_only', False)
 
     def __str__(self):
         return 'MatchesAll(%s)' % ', '.join(map(str, self.matchers))
@@ -494,6 +545,8 @@ class MatchesAll(object):
         for matcher in self.matchers:
             mismatch = matcher.match(matchee)
             if mismatch is not None:
+                if self.first_only:
+                    return mismatch
                 results.append(mismatch)
         if results:
             return MismatchesAll(results)
@@ -784,10 +837,20 @@ class MatchesListwise(object):
     1 != 2
     2 != 1
     ]
+    >>> matcher = MatchesListwise([Equals(1), Equals(2)], first_only=True)
+    >>> print (matcher.match([3, 4]).describe())
+    1 != 3
     """
 
-    def __init__(self, matchers):
+    def __init__(self, matchers, first_only=False):
+        """Construct a MatchesListwise matcher.
+
+        :param matchers: A list of matcher that the matched values must match.
+        :param first_only: If True, then only report the first mismatch,
+            otherwise report all of them. Defaults to False.
+        """
         self.matchers = matchers
+        self.first_only = first_only
 
     def match(self, values):
         mismatches = []
@@ -798,6 +861,8 @@ class MatchesListwise(object):
         for matcher, value in zip(self.matchers, values):
             mismatch = matcher.match(value)
             if mismatch:
+                if self.first_only:
+                    return mismatch
                 mismatches.append(mismatch)
         if mismatches:
             return MismatchesAll(mismatches)
@@ -1054,6 +1119,166 @@ class AllMatch(object):
             return MismatchesAll(mismatches)
 
 
+def PathExists():
+    """Matches if the given path exists.
+
+    Use like this::
+
+      assertThat('/some/path', PathExists())
+    """
+    return MatchesPredicate(os.path.exists, "%s does not exist.")
+
+
+def DirExists():
+    """Matches if the path exists and is a directory."""
+    return MatchesAll(
+        PathExists(),
+        MatchesPredicate(os.path.isdir, "%s is not a directory."),
+        first_only=True)
+
+
+def FileExists():
+    """Matches if the given path exists and is a file."""
+    return MatchesAll(
+        PathExists(),
+        MatchesPredicate(os.path.isfile, "%s is not a file."),
+        first_only=True)
+
+
+class DirContains(Matcher):
+    """Matches if the given directory contains files with the given names.
+
+    That is, is the directory listing exactly equal to the given files?
+    """
+
+    def __init__(self, filenames=None, matcher=None):
+        """Construct a ``DirContains`` matcher.
+
+        Can be used in a basic mode where the whole directory listing is
+        matched against an expected directory listing (by passing
+        ``filenames``).  Can also be used in a more advanced way where the
+        whole directory listing is matched against an arbitrary matcher (by
+        passing ``matcher`` instead).
+
+        :param filenames: If specified, match the sorted directory listing
+            against this list of filenames, sorted.
+        :param matcher: If specified, match the sorted directory listing
+            against this matcher.
+        """
+        if filenames == matcher == None:
+            raise AssertionError(
+                "Must provide one of `filenames` or `matcher`.")
+        if None not in (filenames, matcher):
+            raise AssertionError(
+                "Must provide either `filenames` or `matcher`, not both.")
+        if filenames is None:
+            self.matcher = matcher
+        else:
+            self.matcher = Equals(sorted(filenames))
+
+    def match(self, path):
+        mismatch = DirExists().match(path)
+        if mismatch is not None:
+            return mismatch
+        return self.matcher.match(sorted(os.listdir(path)))
+
+
+class FileContains(Matcher):
+    """Matches if the given file has the specified contents."""
+
+    def __init__(self, contents=None, matcher=None):
+        """Construct a ``FileContains`` matcher.
+
+        Can be used in a basic mode where the file contents are compared for
+        equality against the expected file contents (by passing ``contents``).
+        Can also be used in a more advanced way where the file contents are
+        matched against an arbitrary matcher (by passing ``matcher`` instead).
+
+        :param contents: If specified, match the contents of the file with
+            these contents.
+        :param matcher: If specified, match the contents of the file against
+            this matcher.
+        """
+        if contents == matcher == None:
+            raise AssertionError(
+                "Must provide one of `contents` or `matcher`.")
+        if None not in (contents, matcher):
+            raise AssertionError(
+                "Must provide either `contents` or `matcher`, not both.")
+        if matcher is None:
+            self.matcher = Equals(contents)
+        else:
+            self.matcher = matcher
+
+    def match(self, path):
+        mismatch = PathExists().match(path)
+        if mismatch is not None:
+            return mismatch
+        f = open(path)
+        try:
+            actual_contents = f.read()
+            return self.matcher.match(actual_contents)
+        finally:
+            f.close()
+
+    def __str__(self):
+        return "File at path exists and contains %s" % self.contents
+
+
+class TarballContains(Matcher):
+    """Matches if the given tarball contains the given paths.
+
+    Uses TarFile.getnames() to get the paths out of the tarball.
+    """
+
+    def __init__(self, paths):
+        super(TarballContains, self).__init__()
+        self.paths = paths
+
+    def match(self, tarball_path):
+        tarball = tarfile.open(tarball_path)
+        try:
+            return Equals(sorted(self.paths)).match(sorted(tarball.getnames()))
+        finally:
+            tarball.close()
+
+
+class SamePath(Matcher):
+    """Matches if two paths are the same.
+
+    That is, the paths are equal, or they point to the same file but in
+    different ways.  The paths do not have to exist.
+    """
+
+    def __init__(self, path):
+        super(SamePath, self).__init__()
+        self.path = path
+
+    def match(self, other_path):
+        f = lambda x: os.path.abspath(os.path.realpath(x))
+        return Equals(f(self.path)).match(f(other_path))
+
+
+class HasPermissions(Matcher):
+    """Matches if a file has the given permissions.
+
+    Permissions are specified and matched as a four-digit octal string.
+    """
+
+    def __init__(self, octal_permissions):
+        """Construct a HasPermissions matcher.
+
+        :param octal_permissions: A four digit octal string, representing the
+            intended access permissions. e.g. '0775' for rwxrwxr-x.
+        """
+        super(HasPermissions, self).__init__()
+        self.octal_permissions = octal_permissions
+
+    def match(self, filename):
+        permissions = oct(os.stat(filename).st_mode)[-4:]
+        return Equals(self.octal_permissions).match(permissions)
+
+
 # Signal that this is part of the testing framework, and that code from this
 # should not normally appear in tracebacks.
 __unittest = True