testtools: Merge in new upstream.
[nivanova/samba-autobuild/.git] / lib / testtools / testtools / matchers.py
index 6a4c82a2fe4f7cd32c4949ec8b421f4451a461ec..06b348c6d985131c0277b82d7d178904245e6c65 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2009 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
 
 """Matchers, a way to express complex assertions outside the testcase.
 
@@ -19,12 +19,20 @@ __all__ = [
     'LessThan',
     'MatchesAll',
     'MatchesAny',
+    'MatchesException',
     'NotEquals',
     'Not',
+    'Raises',
+    'raises',
+    'StartsWith',
     ]
 
 import doctest
 import operator
+from pprint import pformat
+import sys
+
+from testtools.compat import classtypes, _error_repr, isbaseexception
 
 
 class Matcher(object):
@@ -100,6 +108,10 @@ class Mismatch(object):
         """
         return getattr(self, '_details', {})
 
+    def __repr__(self):
+        return  "<testtools.matchers.Mismatch object at %x attributes=%r>" % (
+            id(self), self.__dict__)
+
 
 class DocTestMatches(object):
     """See if a string matches a doctest example."""
@@ -151,6 +163,39 @@ class DocTestMismatch(Mismatch):
         return self.matcher._describe_difference(self.with_nl)
 
 
+class DoesNotStartWith(Mismatch):
+
+    def __init__(self, matchee, expected):
+        """Create a DoesNotStartWith Mismatch.
+
+        :param matchee: the string that did not match.
+        :param expected: the string that `matchee` was expected to start
+            with.
+        """
+        self.matchee = matchee
+        self.expected = expected
+
+    def describe(self):
+        return "'%s' does not start with '%s'." % (
+            self.matchee, self.expected)
+
+
+class DoesNotEndWith(Mismatch):
+
+    def __init__(self, matchee, expected):
+        """Create a DoesNotEndWith Mismatch.
+
+        :param matchee: the string that did not match.
+        :param expected: the string that `matchee` was expected to end with.
+        """
+        self.matchee = matchee
+        self.expected = expected
+
+    def describe(self):
+        return "'%s' does not end with '%s'." % (
+            self.matchee, self.expected)
+
+
 class _BinaryComparison(object):
     """Matcher that compares an object to another object."""
 
@@ -178,7 +223,14 @@ class _BinaryMismatch(Mismatch):
         self.other = other
 
     def describe(self):
-        return "%r %s %r" % (self.expected, self._mismatch_string, self.other)
+        left = repr(self.expected)
+        right = repr(self.other)
+        if len(left) + len(right) > 70:
+            return "%s:\nreference = %s\nactual = %s\n" % (
+                self._mismatch_string, pformat(self.expected),
+                pformat(self.other))
+        else:
+            return "%s %s %s" % (left, self._mismatch_string,right)
 
 
 class Equals(_BinaryComparison):
@@ -264,7 +316,7 @@ class MismatchesAll(Mismatch):
         descriptions = ["Differences: ["]
         for mismatch in self.mismatches:
             descriptions.append(mismatch.describe())
-        descriptions.append("]\n")
+        descriptions.append("]")
         return '\n'.join(descriptions)
 
 
@@ -296,6 +348,105 @@ class MatchedUnexpectedly(Mismatch):
         return "%r matches %s" % (self.other, self.matcher)
 
 
+class MatchesException(Matcher):
+    """Match an exc_info tuple against an exception instance or type."""
+
+    def __init__(self, exception):
+        """Create a MatchesException that will match exc_info's for exception.
+
+        :param exception: Either an exception instance or type.
+            If an instance is given, the type and arguments of the exception
+            are checked. If a type is given only the type of the exception is
+            checked.
+        """
+        Matcher.__init__(self)
+        self.expected = exception
+        self._is_instance = type(self.expected) not in classtypes()
+
+    def match(self, other):
+        if type(other) != tuple:
+            return Mismatch('%r is not an exc_info tuple' % other)
+        expected_class = self.expected
+        if self._is_instance:
+            expected_class = expected_class.__class__
+        if not issubclass(other[0], expected_class):
+            return Mismatch('%r is not a %r' % (other[0], expected_class))
+        if self._is_instance and other[1].args != self.expected.args:
+            return Mismatch('%s has different arguments to %s.' % (
+                _error_repr(other[1]), _error_repr(self.expected)))
+
+    def __str__(self):
+        if self._is_instance:
+            return "MatchesException(%s)" % _error_repr(self.expected)
+        return "MatchesException(%s)" % repr(self.expected)
+
+
+class StartsWith(Matcher):
+    """Checks whether one string starts with another."""
+
+    def __init__(self, expected):
+        """Create a StartsWith Matcher.
+
+        :param expected: the string that matchees should start with.
+        """
+        self.expected = expected
+
+    def __str__(self):
+        return "Starts with '%s'." % self.expected
+
+    def match(self, matchee):
+        if not matchee.startswith(self.expected):
+            return DoesNotStartWith(matchee, self.expected)
+        return None
+
+
+class EndsWith(Matcher):
+    """Checks whether one string starts with another."""
+
+    def __init__(self, expected):
+        """Create a EndsWith Matcher.
+
+        :param expected: the string that matchees should end with.
+        """
+        self.expected = expected
+
+    def __str__(self):
+        return "Ends with '%s'." % self.expected
+
+    def match(self, matchee):
+        if not matchee.endswith(self.expected):
+            return DoesNotEndWith(matchee, self.expected)
+        return None
+
+
+class KeysEqual(Matcher):
+    """Checks whether a dict has particular keys."""
+
+    def __init__(self, *expected):
+        """Create a `KeysEqual` Matcher.
+
+        :param *expected: The keys the dict is expected to have.  If a dict,
+            then we use the keys of that dict, if a collection, we assume it
+            is a collection of expected keys.
+        """
+        try:
+            self.expected = expected.keys()
+        except AttributeError:
+            self.expected = list(expected)
+
+    def __str__(self):
+        return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))
+
+    def match(self, matchee):
+        expected = sorted(self.expected)
+        matched = Equals(expected).match(sorted(matchee.keys()))
+        if matched:
+            return AnnotatedMismatch(
+                'Keys not equal',
+                _BinaryMismatch(expected, 'does not match', matchee))
+        return None
+
+
 class Annotate(object):
     """Annotates a matcher with a descriptive string.
 
@@ -324,3 +475,56 @@ class AnnotatedMismatch(Mismatch):
 
     def describe(self):
         return '%s: %s' % (self.mismatch.describe(), self.annotation)
+
+
+class Raises(Matcher):
+    """Match if the matchee raises an exception when called.
+
+    Exceptions which are not subclasses of Exception propogate out of the
+    Raises.match call unless they are explicitly matched.
+    """
+
+    def __init__(self, exception_matcher=None):
+        """Create a Raises matcher.
+
+        :param exception_matcher: Optional validator for the exception raised
+            by matchee. If supplied the exc_info tuple for the exception raised
+            is passed into that matcher. If no exception_matcher is supplied
+            then the simple fact of raising an exception is considered enough
+            to match on.
+        """
+        self.exception_matcher = exception_matcher
+
+    def match(self, matchee):
+        try:
+            result = matchee()
+            return Mismatch('%r returned %r' % (matchee, result))
+        # Catch all exceptions: Raises() should be able to match a
+        # KeyboardInterrupt or SystemExit.
+        except:
+            if self.exception_matcher:
+                mismatch = self.exception_matcher.match(sys.exc_info())
+                if not mismatch:
+                    return
+            else:
+                mismatch = None
+            # The exception did not match, or no explicit matching logic was
+            # performed. If the exception is a non-user exception (that is, not
+            # a subclass of Exception on Python 2.5+) then propogate it.
+            if isbaseexception(sys.exc_info()[1]):
+                raise
+            return mismatch
+
+    def __str__(self):
+        return 'Raises()'
+
+
+def raises(exception):
+    """Make a matcher that checks that a callable raises an exception.
+
+    This is a convenience function, exactly equivalent to::
+        return Raises(MatchesException(exception))
+
+    See `Raises` and `MatchesException` for more information.
+    """
+    return Raises(MatchesException(exception))