testtools: Update to latest upstream snapshot.
[nivanova/samba-autobuild/.git] / lib / testtools / testtools / matchers.py
index 06b348c6d985131c0277b82d7d178904245e6c65..6ee33f0fd828a7a889fc0bd8d8cc25fb6c484f0b 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2011 testtools developers. See LICENSE for details.
 
 """Matchers, a way to express complex assertions outside the testcase.
 
@@ -12,14 +12,25 @@ $ python -c 'import testtools.matchers; print testtools.matchers.__all__'
 
 __metaclass__ = type
 __all__ = [
+    'AfterPreprocessing',
+    'AllMatch',
     'Annotate',
+    'Contains',
     'DocTestMatches',
+    'EndsWith',
     'Equals',
+    'GreaterThan',
     'Is',
+    'IsInstance',
+    'KeysEqual',
     'LessThan',
     'MatchesAll',
     'MatchesAny',
     'MatchesException',
+    'MatchesListwise',
+    'MatchesRegex',
+    'MatchesSetwise',
+    'MatchesStructure',
     'NotEquals',
     'Not',
     'Raises',
@@ -30,9 +41,16 @@ __all__ = [
 import doctest
 import operator
 from pprint import pformat
+import re
 import sys
+import types
 
-from testtools.compat import classtypes, _error_repr, isbaseexception
+from testtools.compat import (
+    classtypes,
+    _error_repr,
+    isbaseexception,
+    istext,
+    )
 
 
 class Matcher(object):
@@ -113,6 +131,69 @@ class Mismatch(object):
             id(self), self.__dict__)
 
 
+class MismatchDecorator(object):
+    """Decorate a ``Mismatch``.
+
+    Forwards all messages to the original mismatch object.  Probably the best
+    way to use this is inherit from this class and then provide your own
+    custom decoration logic.
+    """
+
+    def __init__(self, original):
+        """Construct a `MismatchDecorator`.
+
+        :param original: A `Mismatch` object to decorate.
+        """
+        self.original = original
+
+    def __repr__(self):
+        return '<testtools.matchers.MismatchDecorator(%r)>' % (self.original,)
+
+    def describe(self):
+        return self.original.describe()
+
+    def get_details(self):
+        return self.original.get_details()
+
+
+class _NonManglingOutputChecker(doctest.OutputChecker):
+    """Doctest checker that works with unicode rather than mangling strings
+
+    This is needed because current Python versions have tried to fix string
+    encoding related problems, but regressed the default behaviour with unicode
+    inputs in the process.
+
+    In Python 2.6 and 2.7 `OutputChecker.output_difference` is was changed to
+    return a bytestring encoded as per `sys.stdout.encoding`, or utf-8 if that
+    can't be determined. Worse, that encoding process happens in the innocent
+    looking `_indent` global function. Because the `DocTestMismatch.describe`
+    result may well not be destined for printing to stdout, this is no good
+    for us. To get a unicode return as before, the method is monkey patched if
+    `doctest._encoding` exists.
+
+    Python 3 has a different problem. For some reason both inputs are encoded
+    to ascii with 'backslashreplace', making an escaped string matches its
+    unescaped form. Overriding the offending `OutputChecker._toAscii` method
+    is sufficient to revert this.
+    """
+
+    def _toAscii(self, s):
+        """Return `s` unchanged rather than mangling it to ascii"""
+        return s
+
+    # Only do this overriding hackery if doctest has a broken _input function
+    if getattr(doctest, "_encoding", None) is not None:
+        from types import FunctionType as __F
+        __f = doctest.OutputChecker.output_difference.im_func
+        __g = dict(__f.func_globals)
+        def _indent(s, indent=4, _pattern=re.compile("^(?!$)", re.MULTILINE)):
+            """Prepend non-empty lines in `s` with `indent` number of spaces"""
+            return _pattern.sub(indent*" ", s)
+        __g["_indent"] = _indent
+        output_difference = __F(__f.func_code, __g, "output_difference")
+        del __F, __f, __g, _indent
+
+
 class DocTestMatches(object):
     """See if a string matches a doctest example."""
 
@@ -127,7 +208,7 @@ class DocTestMatches(object):
             example += '\n'
         self.want = example # required variable name by doctest.
         self.flags = flags
-        self._checker = doctest.OutputChecker()
+        self._checker = _NonManglingOutputChecker()
 
     def __str__(self):
         if self.flags:
@@ -137,7 +218,7 @@ class DocTestMatches(object):
         return 'DocTestMatches(%r%s)' % (self.want, flagstr)
 
     def _with_nl(self, actual):
-        result = str(actual)
+        result = self.want.__class__(actual)
         if not result.endswith('\n'):
             result += '\n'
         return result
@@ -163,14 +244,28 @@ class DocTestMismatch(Mismatch):
         return self.matcher._describe_difference(self.with_nl)
 
 
+class DoesNotContain(Mismatch):
+
+    def __init__(self, matchee, needle):
+        """Create a DoesNotContain Mismatch.
+
+        :param matchee: the object that did not contain needle.
+        :param needle: the needle that 'matchee' was expected to contain.
+        """
+        self.matchee = matchee
+        self.needle = needle
+
+    def describe(self):
+        return "%r not in %r" % (self.needle, self.matchee)
+
+
 class DoesNotStartWith(Mismatch):
 
     def __init__(self, matchee, expected):
         """Create a DoesNotStartWith Mismatch.
 
         :param matchee: the string that did not match.
-        :param expected: the string that `matchee` was expected to start
-            with.
+        :param expected: the string that 'matchee' was expected to start with.
         """
         self.matchee = matchee
         self.expected = expected
@@ -186,7 +281,7 @@ class DoesNotEndWith(Mismatch):
         """Create a DoesNotEndWith Mismatch.
 
         :param matchee: the string that did not match.
-        :param expected: the string that `matchee` was expected to end with.
+        :param expected: the string that 'matchee' was expected to end with.
         """
         self.matchee = matchee
         self.expected = expected
@@ -222,13 +317,20 @@ class _BinaryMismatch(Mismatch):
         self._mismatch_string = mismatch_string
         self.other = other
 
+    def _format(self, thing):
+        # Blocks of text with newlines are formatted as triple-quote
+        # strings. Everything else is pretty-printed.
+        if istext(thing) and '\n' in thing:
+            return '"""\\\n%s"""' % (thing,)
+        return pformat(thing)
+
     def describe(self):
         left = repr(self.expected)
         right = repr(self.other)
         if len(left) + len(right) > 70:
             return "%s:\nreference = %s\nactual = %s\n" % (
-                self._mismatch_string, pformat(self.expected),
-                pformat(self.other))
+                self._mismatch_string, self._format(self.expected),
+                self._format(self.other))
         else:
             return "%s %s %s" % (left, self._mismatch_string,right)
 
@@ -243,8 +345,8 @@ class Equals(_BinaryComparison):
 class NotEquals(_BinaryComparison):
     """Matches if the items are not equal.
 
-    In most cases, this is equivalent to `Not(Equals(foo))`. The difference
-    only matters when testing `__ne__` implementations.
+    In most cases, this is equivalent to ``Not(Equals(foo))``. The difference
+    only matters when testing ``__ne__`` implementations.
     """
 
     comparator = operator.ne
@@ -258,11 +360,54 @@ class Is(_BinaryComparison):
     mismatch_string = 'is not'
 
 
+class IsInstance(object):
+    """Matcher that wraps isinstance."""
+
+    def __init__(self, *types):
+        self.types = tuple(types)
+
+    def __str__(self):
+        return "%s(%s)" % (self.__class__.__name__,
+                ', '.join(type.__name__ for type in self.types))
+
+    def match(self, other):
+        if isinstance(other, self.types):
+            return None
+        return NotAnInstance(other, self.types)
+
+
+class NotAnInstance(Mismatch):
+
+    def __init__(self, matchee, types):
+        """Create a NotAnInstance Mismatch.
+
+        :param matchee: the thing which is not an instance of any of types.
+        :param types: A tuple of the types which were expected.
+        """
+        self.matchee = matchee
+        self.types = types
+
+    def describe(self):
+        if len(self.types) == 1:
+            typestr = self.types[0].__name__
+        else:
+            typestr = 'any of (%s)' % ', '.join(type.__name__ for type in
+                    self.types)
+        return "'%s' is not an instance of %s" % (self.matchee, typestr)
+
+
 class LessThan(_BinaryComparison):
     """Matches if the item is less than the matchers reference object."""
 
     comparator = operator.__lt__
-    mismatch_string = 'is >='
+    mismatch_string = 'is not >'
+
+
+class GreaterThan(_BinaryComparison):
+    """Matches if the item is greater than the matchers reference object."""
+
+    comparator = operator.__gt__
+    mismatch_string = 'is not <'
 
 
 class MatchesAny(object):
@@ -351,17 +496,26 @@ class MatchedUnexpectedly(Mismatch):
 class MatchesException(Matcher):
     """Match an exc_info tuple against an exception instance or type."""
 
-    def __init__(self, exception):
+    def __init__(self, exception, value_re=None):
         """Create a MatchesException that will match exc_info's for exception.
 
         :param exception: Either an exception instance or type.
             If an instance is given, the type and arguments of the exception
             are checked. If a type is given only the type of the exception is
-            checked.
+            checked. If a tuple is given, then as with isinstance, any of the
+            types in the tuple matching is sufficient to match.
+        :param value_re: If 'exception' is a type, and the matchee exception
+            is of the right type, then match against this.  If value_re is a
+            string, then assume value_re is a regular expression and match
+            the str() of the exception against it.  Otherwise, assume value_re
+            is a matcher, and match the exception against it.
         """
         Matcher.__init__(self)
         self.expected = exception
-        self._is_instance = type(self.expected) not in classtypes()
+        if istext(value_re):
+            value_re = AfterPreproccessing(str, MatchesRegex(value_re), False)
+        self.value_re = value_re
+        self._is_instance = type(self.expected) not in classtypes() + (tuple,)
 
     def match(self, other):
         if type(other) != tuple:
@@ -371,9 +525,12 @@ class MatchesException(Matcher):
             expected_class = expected_class.__class__
         if not issubclass(other[0], expected_class):
             return Mismatch('%r is not a %r' % (other[0], expected_class))
-        if self._is_instance and other[1].args != self.expected.args:
-            return Mismatch('%s has different arguments to %s.' % (
-                _error_repr(other[1]), _error_repr(self.expected)))
+        if self._is_instance:
+            if other[1].args != self.expected.args:
+                return Mismatch('%s has different arguments to %s.' % (
+                        _error_repr(other[1]), _error_repr(self.expected)))
+        elif self.value_re is not None:
+            return self.value_re.match(other[1])
 
     def __str__(self):
         if self._is_instance:
@@ -381,6 +538,29 @@ class MatchesException(Matcher):
         return "MatchesException(%s)" % repr(self.expected)
 
 
+class Contains(Matcher):
+    """Checks whether something is container in another thing."""
+
+    def __init__(self, needle):
+        """Create a Contains Matcher.
+
+        :param needle: the thing that needs to be contained by matchees.
+        """
+        self.needle = needle
+
+    def __str__(self):
+        return "Contains(%r)" % (self.needle,)
+
+    def match(self, matchee):
+        try:
+            if self.needle not in matchee:
+                return DoesNotContain(matchee, self.needle)
+        except TypeError:
+            # e.g. 1 in 2 will raise TypeError
+            return DoesNotContain(matchee, self.needle)
+        return None
+
+
 class StartsWith(Matcher):
     """Checks whether one string starts with another."""
 
@@ -425,7 +605,7 @@ class KeysEqual(Matcher):
     def __init__(self, *expected):
         """Create a `KeysEqual` Matcher.
 
-        :param *expected: The keys the dict is expected to have.  If a dict,
+        :param expected: The keys the dict is expected to have.  If a dict,
             then we use the keys of that dict, if a collection, we assume it
             is a collection of expected keys.
         """
@@ -457,6 +637,13 @@ class Annotate(object):
         self.annotation = annotation
         self.matcher = matcher
 
+    @classmethod
+    def if_message(cls, annotation, matcher):
+        """Annotate ``matcher`` only if ``annotation`` is non-empty."""
+        if not annotation:
+            return matcher
+        return cls(annotation, matcher)
+
     def __str__(self):
         return 'Annotate(%r, %s)' % (self.annotation, self.matcher)
 
@@ -466,15 +653,16 @@ class Annotate(object):
             return AnnotatedMismatch(self.annotation, mismatch)
 
 
-class AnnotatedMismatch(Mismatch):
+class AnnotatedMismatch(MismatchDecorator):
     """A mismatch annotated with a descriptive string."""
 
     def __init__(self, annotation, mismatch):
+        super(AnnotatedMismatch, self).__init__(mismatch)
         self.annotation = annotation
         self.mismatch = mismatch
 
     def describe(self):
-        return '%s: %s' % (self.mismatch.describe(), self.annotation)
+        return '%s: %s' % (self.original.describe(), self.annotation)
 
 
 class Raises(Matcher):
@@ -502,16 +690,19 @@ class Raises(Matcher):
         # Catch all exceptions: Raises() should be able to match a
         # KeyboardInterrupt or SystemExit.
         except:
+            exc_info = sys.exc_info()
             if self.exception_matcher:
-                mismatch = self.exception_matcher.match(sys.exc_info())
+                mismatch = self.exception_matcher.match(exc_info)
                 if not mismatch:
+                    del exc_info
                     return
             else:
                 mismatch = None
             # The exception did not match, or no explicit matching logic was
             # performed. If the exception is a non-user exception (that is, not
             # a subclass of Exception on Python 2.5+) then propogate it.
-            if isbaseexception(sys.exc_info()[1]):
+            if isbaseexception(exc_info[1]):
+                del exc_info
                 raise
             return mismatch
 
@@ -523,8 +714,292 @@ def raises(exception):
     """Make a matcher that checks that a callable raises an exception.
 
     This is a convenience function, exactly equivalent to::
+
         return Raises(MatchesException(exception))
 
     See `Raises` and `MatchesException` for more information.
     """
     return Raises(MatchesException(exception))
+
+
+class MatchesListwise(object):
+    """Matches if each matcher matches the corresponding value.
+
+    More easily explained by example than in words:
+
+    >>> MatchesListwise([Equals(1)]).match([1])
+    >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
+    >>> print (MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
+    Differences: [
+    1 != 2
+    2 != 1
+    ]
+    """
+
+    def __init__(self, matchers):
+        self.matchers = matchers
+
+    def match(self, values):
+        mismatches = []
+        length_mismatch = Annotate(
+            "Length mismatch", Equals(len(self.matchers))).match(len(values))
+        if length_mismatch:
+            mismatches.append(length_mismatch)
+        for matcher, value in zip(self.matchers, values):
+            mismatch = matcher.match(value)
+            if mismatch:
+                mismatches.append(mismatch)
+        if mismatches:
+            return MismatchesAll(mismatches)
+
+
+class MatchesStructure(object):
+    """Matcher that matches an object structurally.
+
+    'Structurally' here means that attributes of the object being matched are
+    compared against given matchers.
+
+    `fromExample` allows the creation of a matcher from a prototype object and
+    then modified versions can be created with `update`.
+
+    `byEquality` creates a matcher in much the same way as the constructor,
+    except that the matcher for each of the attributes is assumed to be
+    `Equals`.
+
+    `byMatcher` creates a similar matcher to `byEquality`, but you get to pick
+    the matcher, rather than just using `Equals`.
+    """
+
+    def __init__(self, **kwargs):
+        """Construct a `MatchesStructure`.
+
+        :param kwargs: A mapping of attributes to matchers.
+        """
+        self.kws = kwargs
+
+    @classmethod
+    def byEquality(cls, **kwargs):
+        """Matches an object where the attributes equal the keyword values.
+
+        Similar to the constructor, except that the matcher is assumed to be
+        Equals.
+        """
+        return cls.byMatcher(Equals, **kwargs)
+
+    @classmethod
+    def byMatcher(cls, matcher, **kwargs):
+        """Matches an object where the attributes match the keyword values.
+
+        Similar to the constructor, except that the provided matcher is used
+        to match all of the values.
+        """
+        return cls(
+            **dict((name, matcher(value)) for name, value in kwargs.items()))
+
+    @classmethod
+    def fromExample(cls, example, *attributes):
+        kwargs = {}
+        for attr in attributes:
+            kwargs[attr] = Equals(getattr(example, attr))
+        return cls(**kwargs)
+
+    def update(self, **kws):
+        new_kws = self.kws.copy()
+        for attr, matcher in kws.items():
+            if matcher is None:
+                new_kws.pop(attr, None)
+            else:
+                new_kws[attr] = matcher
+        return type(self)(**new_kws)
+
+    def __str__(self):
+        kws = []
+        for attr, matcher in sorted(self.kws.items()):
+            kws.append("%s=%s" % (attr, matcher))
+        return "%s(%s)" % (self.__class__.__name__, ', '.join(kws))
+
+    def match(self, value):
+        matchers = []
+        values = []
+        for attr, matcher in sorted(self.kws.items()):
+            matchers.append(Annotate(attr, matcher))
+            values.append(getattr(value, attr))
+        return MatchesListwise(matchers).match(values)
+
+
+class MatchesRegex(object):
+    """Matches if the matchee is matched by a regular expression."""
+
+    def __init__(self, pattern, flags=0):
+        self.pattern = pattern
+        self.flags = flags
+
+    def __str__(self):
+        args = ['%r' % self.pattern]
+        flag_arg = []
+        # dir() sorts the attributes for us, so we don't need to do it again.
+        for flag in dir(re):
+            if len(flag) == 1:
+                if self.flags & getattr(re, flag):
+                    flag_arg.append('re.%s' % flag)
+        if flag_arg:
+            args.append('|'.join(flag_arg))
+        return '%s(%s)' % (self.__class__.__name__, ', '.join(args))
+
+    def match(self, value):
+        if not re.match(self.pattern, value, self.flags):
+            return Mismatch("%r does not match /%s/" % (
+                    value, self.pattern))
+
+
+class MatchesSetwise(object):
+    """Matches if all the matchers match elements of the value being matched.
+
+    That is, each element in the 'observed' set must match exactly one matcher
+    from the set of matchers, with no matchers left over.
+
+    The difference compared to `MatchesListwise` is that the order of the
+    matchings does not matter.
+    """
+
+    def __init__(self, *matchers):
+        self.matchers = matchers
+
+    def match(self, observed):
+        remaining_matchers = set(self.matchers)
+        not_matched = []
+        for value in observed:
+            for matcher in remaining_matchers:
+                if matcher.match(value) is None:
+                    remaining_matchers.remove(matcher)
+                    break
+            else:
+                not_matched.append(value)
+        if not_matched or remaining_matchers:
+            remaining_matchers = list(remaining_matchers)
+            # There are various cases that all should be reported somewhat
+            # differently.
+
+            # There are two trivial cases:
+            # 1) There are just some matchers left over.
+            # 2) There are just some values left over.
+
+            # Then there are three more interesting cases:
+            # 3) There are the same number of matchers and values left over.
+            # 4) There are more matchers left over than values.
+            # 5) There are more values left over than matchers.
+
+            if len(not_matched) == 0:
+                if len(remaining_matchers) > 1:
+                    msg = "There were %s matchers left over: " % (
+                        len(remaining_matchers),)
+                else:
+                    msg = "There was 1 matcher left over: "
+                msg += ', '.join(map(str, remaining_matchers))
+                return Mismatch(msg)
+            elif len(remaining_matchers) == 0:
+                if len(not_matched) > 1:
+                    return Mismatch(
+                        "There were %s values left over: %s" % (
+                            len(not_matched), not_matched))
+                else:
+                    return Mismatch(
+                        "There was 1 value left over: %s" % (
+                            not_matched, ))
+            else:
+                common_length = min(len(remaining_matchers), len(not_matched))
+                if common_length == 0:
+                    raise AssertionError("common_length can't be 0 here")
+                if common_length > 1:
+                    msg = "There were %s mismatches" % (common_length,)
+                else:
+                    msg = "There was 1 mismatch"
+                if len(remaining_matchers) > len(not_matched):
+                    extra_matchers = remaining_matchers[common_length:]
+                    msg += " and %s extra matcher" % (len(extra_matchers), )
+                    if len(extra_matchers) > 1:
+                        msg += "s"
+                    msg += ': ' + ', '.join(map(str, extra_matchers))
+                elif len(not_matched) > len(remaining_matchers):
+                    extra_values = not_matched[common_length:]
+                    msg += " and %s extra value" % (len(extra_values), )
+                    if len(extra_values) > 1:
+                        msg += "s"
+                    msg += ': ' + str(extra_values)
+                return Annotate(
+                    msg, MatchesListwise(remaining_matchers[:common_length])
+                    ).match(not_matched[:common_length])
+
+
+class AfterPreprocessing(object):
+    """Matches if the value matches after passing through a function.
+
+    This can be used to aid in creating trivial matchers as functions, for
+    example::
+
+      def PathHasFileContent(content):
+          def _read(path):
+              return open(path).read()
+          return AfterPreprocessing(_read, Equals(content))
+    """
+
+    def __init__(self, preprocessor, matcher, annotate=True):
+        """Create an AfterPreprocessing matcher.
+
+        :param preprocessor: A function called with the matchee before
+            matching.
+        :param matcher: What to match the preprocessed matchee against.
+        :param annotate: Whether or not to annotate the matcher with
+            something explaining how we transformed the matchee. Defaults
+            to True.
+        """
+        self.preprocessor = preprocessor
+        self.matcher = matcher
+        self.annotate = annotate
+
+    def _str_preprocessor(self):
+        if isinstance(self.preprocessor, types.FunctionType):
+            return '<function %s>' % self.preprocessor.__name__
+        return str(self.preprocessor)
+
+    def __str__(self):
+        return "AfterPreprocessing(%s, %s)" % (
+            self._str_preprocessor(), self.matcher)
+
+    def match(self, value):
+        after = self.preprocessor(value)
+        if self.annotate:
+            matcher = Annotate(
+                "after %s on %r" % (self._str_preprocessor(), value),
+                self.matcher)
+        else:
+            matcher = self.matcher
+        return matcher.match(after)
+
+# This is the old, deprecated. spelling of the name, kept for backwards
+# compatibility.
+AfterPreproccessing = AfterPreprocessing
+
+
+class AllMatch(object):
+    """Matches if all provided values match the given matcher."""
+
+    def __init__(self, matcher):
+        self.matcher = matcher
+
+    def __str__(self):
+        return 'AllMatch(%s)' % (self.matcher,)
+
+    def match(self, values):
+        mismatches = []
+        for value in values:
+            mismatch = self.matcher.match(value)
+            if mismatch:
+                mismatches.append(mismatch)
+        if mismatches:
+            return MismatchesAll(mismatches)
+
+
+# Signal that this is part of the testing framework, and that code from this
+# should not normally appear in tracebacks.
+__unittest = True