testtools: Import new upstream snapshot.
author Jelmer Vernooij <jelmer@samba.org>
Thu, 9 Dec 2010 13:51:17 +0000 (14:51 +0100)
committer Jelmer Vernooij <jelmer@samba.org>
Fri, 10 Dec 2010 02:04:06 +0000 (03:04 +0100)
36 files changed:
lib/testtools/.testr.conf [new file with mode: 0644]
lib/testtools/HACKING
lib/testtools/MANIFEST.in
lib/testtools/MANUAL
lib/testtools/Makefile
lib/testtools/NEWS
lib/testtools/README
lib/testtools/setup.py
lib/testtools/testtools/__init__.py
lib/testtools/testtools/_spinner.py [new file with mode: 0644]
lib/testtools/testtools/compat.py
lib/testtools/testtools/content.py
lib/testtools/testtools/deferredruntest.py [new file with mode: 0644]
lib/testtools/testtools/helpers.py [new file with mode: 0644]
lib/testtools/testtools/matchers.py
lib/testtools/testtools/run.py
lib/testtools/testtools/runtest.py
lib/testtools/testtools/testcase.py
lib/testtools/testtools/testresult/doubles.py
lib/testtools/testtools/testresult/real.py
lib/testtools/testtools/tests/__init__.py
lib/testtools/testtools/tests/helpers.py
lib/testtools/testtools/tests/test_compat.py
lib/testtools/testtools/tests/test_content.py
lib/testtools/testtools/tests/test_content_type.py
lib/testtools/testtools/tests/test_deferredruntest.py [new file with mode: 0644]
lib/testtools/testtools/tests/test_fixturesupport.py [new file with mode: 0644]
lib/testtools/testtools/tests/test_helpers.py [new file with mode: 0644]
lib/testtools/testtools/tests/test_matchers.py
lib/testtools/testtools/tests/test_monkey.py
lib/testtools/testtools/tests/test_run.py [new file with mode: 0644]
lib/testtools/testtools/tests/test_runtest.py
lib/testtools/testtools/tests/test_spinner.py [new file with mode: 0644]
lib/testtools/testtools/tests/test_testresult.py
lib/testtools/testtools/tests/test_testsuite.py
lib/testtools/testtools/tests/test_testtools.py

diff --git a/lib/testtools/.testr.conf b/lib/testtools/.testr.conf
new file mode 100644 (file)
index 0000000..12d6685
--- /dev/null
@@ -0,0 +1,4 @@
+[DEFAULT]
+test_command=PYTHONPATH=. python -m subunit.run $LISTOPT $IDOPTION testtools.tests.test_suite
+test_id_option=--load-list $IDFILE
+test_list_option=--list
diff --git a/lib/testtools/HACKING b/lib/testtools/HACKING
index 60b1a90a8c7d139052ce6b629d3be572994d6fab..cc1a88f15496afa6389a438ca66408797176f694 100644 (file)
@@ -111,30 +111,24 @@ permanently present at the top of the list.
 Release tasks
 -------------
 
-In no particular order:
-
-* Choose a version number.
-
-* Ensure __init__ has that version.
-
-* Add a version number to NEWS immediately below NEXT.
-
-* Possibly write a blurb into NEWS.
-
-* Replace any additional references to NEXT with the version being released.
-  (should be none).
-
-* Create a source distribution and upload to pypi ('make release').
-
-* Upload to Launchpad as well.
-
-* If a new series has been created (e.g. 0.10.0), make the series on Launchpad.
-
-* Merge or push the release branch to trunk.
-
-* Make a new milestone for the *next release*. We don't really know how we want
-  to handle these yet, so this is a suggestion not actual practice:
-
-  * during release we rename NEXT to $version.
-
-  * we call new milestones NEXT.
+ 1. Choose a version number, say X.Y.Z
+ 1. Branch from trunk to testtools-X.Y.Z
+ 1. In testtools-X.Y.Z, ensure __init__ has version X.Y.Z.
+ 1. Replace NEXT in NEWS with the version number X.Y.Z, adjusting the reST.
+ 1. Possibly write a blurb into NEWS.
+ 1. Replace any additional references to NEXT with the version being
+    released. (should be none).
+ 1. Commit the changes.
+ 1. Tag the release, bzr tag testtools-X.Y.Z
+ 1. Create a source distribution and upload to pypi ('make release').
+ 1. Make sure all "Fix committed" bugs are in the 'next' milestone on
+    Launchpad
+ 1. Rename the 'next' milestone on Launchpad to 'X.Y.Z'
+ 1. Create a release on the newly-renamed 'X.Y.Z' milestone
+ 1. Upload the tarball and asc file to Launchpad
+ 1. Merge the release branch testtools-X.Y.Z into trunk. Before the commit,
+    add a NEXT heading to the top of NEWS. Push trunk to Launchpad.
+ 1. If a new series has been created (e.g. 0.10.0), make the series on Launchpad.
+ 1. Make a new milestone for the *next release*.
+    1. During release we rename NEXT to $version.
+    1. We call new milestones NEXT.
diff --git a/lib/testtools/MANIFEST.in b/lib/testtools/MANIFEST.in
index 3296ee4c0e4f40a0602066d8f75fd1540fad1be6..6d1bf1170f259b271ba68aecfd1b0ba00366c8f8 100644 (file)
@@ -5,5 +5,4 @@ include MANIFEST.in
 include MANUAL
 include NEWS
 include README
-include run-tests
 include .bzrignore
diff --git a/lib/testtools/MANUAL b/lib/testtools/MANUAL
index 1a43e70f23713fd85d51a93e1da873f9badaee4a..7e7853c7e7b148a5501bee1568c226c81a1133b5 100644 (file)
@@ -11,11 +11,12 @@ to the API docs (i.e. docstrings) for full details on a particular feature.
 Extensions to TestCase
 ----------------------
 
-Controlling test execution
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+Custom exception handling
+~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Testtools supports two ways to control how tests are executed. The simplest
-is to add a new exception to self.exception_handlers::
+testtools provides a way to control how test exceptions are handled.  To do
+this, add a new (exception class, handler) pair to self.exception_handlers on
+a TestCase.  For example::
 
     >>> self.exception_handlers.insert(-1, (ExceptionClass, handler)).
 
@@ -23,12 +24,36 @@ Having done this, if any of setUp, tearDown, or the test method raise
 ExceptionClass, handler will be called with the test case, test result and the
 raised exception.
 
-Secondly, by overriding __init__ to pass in runTest=RunTestFactory the whole
-execution of the test can be altered. The default is testtools.runtest.RunTest
-and calls  case._run_setup, case._run_test_method and finally
-case._run_teardown. Other methods to control what RunTest is used may be
-added in future.
+Controlling test execution
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If you want to control more than just how exceptions are raised, you can
+provide a custom `RunTest` to a TestCase.  The `RunTest` object can change
+everything about how the test executes.
+
+To work with `testtools.TestCase`, a `RunTest` must have a factory that takes
+a test and an optional list of exception handlers.  Instances returned by the
+factory must have a `run()` method that takes an optional `TestResult` object.
+
+The default is `testtools.runtest.RunTest`, which calls 'setUp', the test
+method and 'tearDown' in the normal, vanilla way that Python's standard
+unittest does.
+
+To specify a `RunTest` for all the tests in a `TestCase` class, do something
+like this::
+
+  class SomeTests(TestCase):
+      run_tests_with = CustomRunTestFactory
 
+To specify a `RunTest` for a specific test in a `TestCase` class, do::
+
+  class SomeTests(TestCase):
+      @run_test_with(CustomRunTestFactory, extra_arg=42, foo='whatever')
+      def test_something(self):
+          pass
+
+In addition, either of these can be overridden by passing a factory to the
+`TestCase` constructor with the optional 'runTest' argument.
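As an illustration of the factory interface described above, a hypothetical
``LoggingRunTest`` that subclasses the default ``RunTest`` might look like
this (a sketch, not part of the upstream MANUAL)::

    from testtools import TestCase
    from testtools.runtest import RunTest

    class LoggingRunTest(RunTest):
        """Hypothetical RunTest that reports when a test starts and finishes."""

        def run(self, result=None):
            # RunTest stores the test it was constructed with as self.case.
            print('starting %s' % self.case.id())
            try:
                return super(LoggingRunTest, self).run(result)
            finally:
                print('finished %s' % self.case.id())

    class SomeTests(TestCase):
        run_tests_with = LoggingRunTest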
 
 TestCase.addCleanup
 ~~~~~~~~~~~~~~~~~~~
@@ -91,6 +116,16 @@ instead. ``skipTest`` was previously known as ``skip`` but as Python 2.7 adds
 ``skipTest`` support, the ``skip`` name is now deprecated (but no warning
 is emitted yet - some time in the future we may do so).
 
+TestCase.useFixture
+~~~~~~~~~~~~~~~~~~~
+
+``useFixture(fixture)`` calls setUp on the fixture, schedules a cleanup to 
+clean it up, and schedules a cleanup to attach all details held by the 
+fixture to the details dict of the test case. The fixture object should meet
+the ``fixtures.Fixture`` protocol (version 0.3.4 or newer). This is useful
+for moving code out of setUp and tearDown methods and into composable side
+classes.
+
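For instance, a minimal sketch, assuming the ``fixtures`` package (0.3.4 or
newer) is installed and provides its standard ``TempDir`` fixture::

    from fixtures import TempDir
    from testtools import TestCase

    class TempDirTest(TestCase):

        def test_creates_file(self):
            # setUp is called on the fixture; cleanup and detail-gathering
            # are scheduled automatically.
            tempdir = self.useFixture(TempDir())
            open(tempdir.path + '/example.txt', 'w').close()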
 
 New assertion methods
 ~~~~~~~~~~~~~~~~~~~~~
@@ -115,6 +150,20 @@ asserting more things about the exception than just the type::
         self.assertEqual('bob', error.username)
         self.assertEqual('User bob cannot frobnicate', str(error))
 
+Note that this is incompatible with the assertRaises in unittest2/Python 2.7.
+While we have no immediate plans to make it compatible, consider using the
+new assertThat facility instead::
+
+        self.assertThat(
+            lambda: thing.frobnicate('foo', 'bar'),
+            Raises(MatchesException(UnauthorisedError('bob'))))
+
+There is also a convenience function to handle this common case::
+
+        self.assertThat(
+            lambda: thing.frobnicate('foo', 'bar'),
+            raises(UnauthorisedError('bob')))
+
 
 TestCase.assertThat
 ~~~~~~~~~~~~~~~~~~~
@@ -234,13 +283,17 @@ ThreadsafeForwardingResult to coalesce their activity.
 Running tests
 -------------
 
-Testtools provides a convenient way to run a test suite using the testtools
+testtools provides a convenient way to run a test suite using the testtools
 result object: python -m testtools.run testspec [testspec...].
 
+To run tests with Python 2.4, you'll have to do something like:
+  python2.4 /path/to/testtools/run.py testspec [testspec ...].
+
+
 Test discovery
 --------------
 
-Testtools includes a backported version of the Python 2.7 glue for using the
+testtools includes a backported version of the Python 2.7 glue for using the
 discover test discovery module. If you either have Python 2.7/3.1 or newer, or
 install the 'discover' module, then you can invoke discovery::
 
@@ -249,3 +302,48 @@ install the 'discover' module, then you can invoke discovery::
 For more information see the Python 2.7 unittest documentation, or::
 
     python -m testtools.run --help
+
+
+Twisted support
+---------------
+
+Support for running Twisted tests is very experimental right now.  You
+shouldn't really do it.  However, if you are going to, here are some tips for
+converting your Trial tests into testtools tests (a minimal sketch follows
+the list below).
+
+ * Use the AsynchronousDeferredRunTest runner
+ * Make sure to upcall to setUp and tearDown
+ * Don't use setUpClass or tearDownClass
+ * Don't expect setting .todo, .timeout or .skip attributes to do anything
+ * flushLoggedErrors is not there for you.  Sorry.
+ * assertFailure is not there for you.  Even more sorry.
+
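A converted test might look like the following sketch (assuming Twisted is
installed; the class and test names are hypothetical)::

    from twisted.internet import defer
    from testtools import TestCase
    from testtools.deferredruntest import AsynchronousDeferredRunTest

    class ConvertedTrialTest(TestCase):

        run_tests_with = AsynchronousDeferredRunTest

        def test_deferred_result(self):
            # The returned Deferred is waited on by the runner.
            d = defer.succeed(42)
            d.addCallback(self.assertEqual, 42)
            return d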
+
+General helpers
+---------------
+
+Lots of the time we would like to conditionally import modules.  testtools
+needs to do this itself, and graciously extends the ability to its users.
+
+Instead of::
+
+    try:
+        from twisted.internet import defer
+    except ImportError:
+        defer = None
+
+You can do::
+
+    defer = try_import('twisted.internet.defer')
+
+
+Instead of::
+
+    try:
+        from StringIO import StringIO
+    except ImportError:
+        from io import StringIO
+
+You can do::
+
+    StringIO = try_imports(['StringIO.StringIO', 'io.StringIO'])
diff --git a/lib/testtools/Makefile b/lib/testtools/Makefile
index 0ad6f131d10846951c1917b87ed10c5374542dfd..c36fbd8012cd2cd22169aaf392d78d2675c71642 100644 (file)
@@ -16,15 +16,20 @@ clean:
        rm -f TAGS tags
        find testtools -name "*.pyc" -exec rm '{}' \;
 
-release:
+prerelease:
        # An existing MANIFEST breaks distutils sometimes. Avoid that.
        -rm MANIFEST
+
+release:
        ./setup.py sdist upload --sign
 
+snapshot: prerelease
+       ./setup.py sdist
+
 apidocs:
        pydoctor --make-html --add-package testtools \
                --docformat=restructuredtext --project-name=testtools \
                --project-url=https://launchpad.net/testtools
 
 
-.PHONY: check clean release apidocs
+.PHONY: check clean prerelease release apidocs
diff --git a/lib/testtools/NEWS b/lib/testtools/NEWS
index 89b942fdbe37bf50306e741bb0c0d0366c602680..55193080baaf021751bace11c0dd0e51c7850be4 100644 (file)
@@ -4,6 +4,92 @@ testtools NEWS
 NEXT
 ~~~~
 
+Changes
+-------
+
+* addUnexpectedSuccess is translated to addFailure for test results that don't
+  know about addUnexpectedSuccess.  Further, it fails the entire result for
+  all testtools TestResults (i.e. wasSuccessful() returns False after
+  addUnexpectedSuccess has been called). Note that when using a delegating
+  result such as ThreadsafeForwardingResult, MultiTestResult or
+  ExtendedToOriginalDecorator then the behaviour of addUnexpectedSuccess is
+  determined by the delegated-to result(s).
+  (Jonathan Lange, Robert Collins, #654474, #683332)
+
+* startTestRun will reset any errors on the result.  That is, wasSuccessful()
+  will always return True immediately after startTestRun() is called. This
+  only applies to delegated test results (ThreadsafeForwardingResult,
+  MultiTestResult and ExtendedToOriginalDecorator) if the delegated to result
+  is a testtools test result - we cannot reliably reset the state of unknown
+  test result class instances. (Jonathan Lange, Robert Collins, #683332)
+
+* Responsibility for running test cleanups has been moved to ``RunTest``.
+  This change does not affect public APIs and can be safely ignored by test
+  authors.  (Jonathan Lange, #662647)
+
+Improvements
+------------
+
+* Experimental support for running tests that return Deferreds.
+  (Jonathan Lange, Martin [gz])
+
+* Provide a per-test decorator, run_test_with, to specify which RunTest
+  object to use for a given test.  (Jonathan Lange, #657780)
+
+* Fix the runTest parameter of TestCase to actually work, rather than raising
+  a TypeError.  (Jonathan Lange, #657760)
+
+* New matcher ``EndsWith`` added to complement the existing ``StartsWith``
+  matcher.  (Jonathan Lange, #669165)
+
+* Non-release snapshots of testtools will now work with buildout.
+  (Jonathan Lange, #613734)
+
+* Malformed SyntaxErrors no longer blow up the test suite.  (Martin [gz])
+
+* ``MatchesException`` added to the ``testtools.matchers`` module - matches
+  an exception class and parameters. (Robert Collins)
+
+* New ``KeysEqual`` matcher.  (Jonathan Lange)
+
+* New helpers for conditionally importing modules, ``try_import`` and
+  ``try_imports``.  (Jonathan Lange)
+
+* ``Raises`` added to the ``testtools.matchers`` module - matches if the
+  supplied callable raises, and delegates to an optional matcher for validation
+  of the exception. (Robert Collins)
+
+* ``raises`` added to the ``testtools.matchers`` module - matches if the
+  supplied callable raises and delegates to ``MatchesException`` to validate
+  the exception. (Jonathan Lange)
+
+* ``testtools.TestCase.useFixture`` has been added to glue with fixtures nicely.
+  (Robert Collins)
+
+* ``testtools.run`` now supports ``-l`` to list tests rather than executing
+  them. This is useful for integration with external test analysis/processing
+  tools like subunit and testrepository. (Robert Collins)
+
+* ``testtools.run`` now supports ``--load-list``, which takes a file containing
+  test ids, one per line, and intersects those ids with the tests found. This
+  allows fine grained control of what tests are run even when the tests cannot
+  be named as objects to import (e.g. due to test parameterisation via
+  testscenarios). (Robert Collins)
+
+* Update documentation to say how to use testtools.run() on Python 2.4.
+  (Jonathan Lange, #501174)
+
+* ``text_content`` conveniently converts a Python string to a Content object.
+  (Jonathan Lange, James Westby)
+
+
+
+0.9.7
+~~~~~
+
+Lots of little cleanups in this release; many small improvements to make your
+testing life more pleasant.
+
 Improvements
 ------------
 
@@ -16,6 +102,22 @@ Improvements
 * In normal circumstances, a TestCase will no longer share details with clones
   of itself. (Andrew Bennetts, bug #637725)
 
+* Fewer exception object cycles are generated (reducing peak memory use
+  between garbage collections). (Martin [gz])
+
+* New matchers 'DoesNotStartWith' and 'StartsWith' contributed by Canonical
+  from the Launchpad project. Written by James Westby.
+
+* Timestamps as produced by subunit protocol clients are now forwarded in the
+  ThreadsafeForwardingResult so correct test durations can be reported.
+  (Martin [gz], Robert Collins, #625594)
+
+* With unittest from Python 2.7 skipped tests will now show only the reason
+  rather than a serialisation of all details. (Martin [gz], #625583)
+
+* The testtools release process is now a little better documented and a little
+  smoother.  (Jonathan Lange, #623483, #623487)
+
 
 0.9.6
 ~~~~~
diff --git a/lib/testtools/README b/lib/testtools/README
index 991f3d5a0629c54a81778e8e1f39f0e82125fb5c..83120f01e4f3b33c9aa97a6ab524402c6a4ef14b 100644 (file)
@@ -19,11 +19,21 @@ is copyright Steve Purcell and the Python Software Foundation, it is
 distributed under the same license as Python, see LICENSE for details.
 
 
-Dependencies
-------------
+Required Dependencies
+---------------------
 
  * Python 2.4+ or 3.0+
 
+Optional Dependencies
+---------------------
+
+If you would like to use our undocumented, unsupported Twisted support, then
+you will need Twisted.
+
+If you want to use ``fixtures`` then you can either install fixtures (e.g. from
+https://launchpad.net/python-fixtures or http://pypi.python.org/pypi/fixtures)
+or alternatively just make sure your fixture objects obey the same protocol.
+
 
 Bug reports and patches
 -----------------------
@@ -56,3 +66,7 @@ Thanks
  * Robert Collins
  * Andrew Bennetts
  * Benjamin Peterson
+ * Jamu Kakar
+ * James Westby
+ * Martin [gz]
+ * Michael Hudson-Doyle
diff --git a/lib/testtools/setup.py b/lib/testtools/setup.py
index d7ed46f79f6cf4c70e5075bdf73ac35f7facd015..59e5804f05bd3ec1c40b95f362dc24e76dc6c9a5 100755 (executable)
@@ -2,18 +2,55 @@
 """Distutils installer for testtools."""
 
 from distutils.core import setup
+import email
+import os
+
 import testtools
-version = '.'.join(str(component) for component in testtools.__version__[0:3])
-phase = testtools.__version__[3]
-if phase != 'final':
+
+
+def get_revno():
     import bzrlib.workingtree
     t = bzrlib.workingtree.WorkingTree.open_containing(__file__)[0]
+    return t.branch.revno()
+
+
+def get_version_from_pkg_info():
+    """Get the version from PKG-INFO file if we can."""
+    pkg_info_path = os.path.join(os.path.dirname(__file__), 'PKG-INFO')
+    try:
+        pkg_info_file = open(pkg_info_path, 'r')
+    except (IOError, OSError):
+        return None
+    try:
+        pkg_info = email.message_from_file(pkg_info_file)
+    except email.MessageError:
+        return None
+    return pkg_info.get('Version', None)
+
+
+def get_version():
+    """Return the version of testtools that we are building."""
+    version = '.'.join(
+        str(component) for component in testtools.__version__[0:3])
+    phase = testtools.__version__[3]
+    if phase == 'final':
+        return version
+    pkg_info_version = get_version_from_pkg_info()
+    if pkg_info_version:
+        return pkg_info_version
+    revno = get_revno()
     if phase == 'alpha':
         # No idea what the next version will be
-        version = 'next-%s' % t.branch.revno()
+        return 'next-r%s' % revno
     else:
         # Preserve the version number but give it a revno prefix
-        version = version + '~%s' % t.branch.revno()
+        return version + '-r%s' % revno
+
+
+def get_long_description():
+    manual_path = os.path.join(os.path.dirname(__file__), 'MANUAL')
+    return open(manual_path).read()
+
 
 setup(name='testtools',
       author='Jonathan M. Lange',
@@ -21,5 +58,7 @@ setup(name='testtools',
       url='https://launchpad.net/testtools',
       description=('Extensions to the Python standard library unit testing '
                    'framework'),
-      version=version,
+      long_description=get_long_description(),
+      version=get_version(),
+      classifiers=["License :: OSI Approved :: MIT License"],
       packages=['testtools', 'testtools.testresult', 'testtools.tests'])
diff --git a/lib/testtools/testtools/__init__.py b/lib/testtools/testtools/__init__.py
index 2b76a5eef7ff0556ee213df66f7d4bdeb060681b..0f85426aa70e337f812e8e8b4f9902226a4412b2 100644 (file)
@@ -11,6 +11,7 @@ __all__ = [
     'MultipleExceptions',
     'MultiTestResult',
     'PlaceHolder',
+    'run_test_with',
     'TestCase',
     'TestResult',
     'TextTestResult',
@@ -19,20 +20,27 @@ __all__ = [
     'skipIf',
     'skipUnless',
     'ThreadsafeForwardingResult',
+    'try_import',
+    'try_imports',
     ]
 
+from testtools.helpers import (
+    try_import,
+    try_imports,
+    )
 from testtools.matchers import (
     Matcher,
     )
 from testtools.runtest import (
+    MultipleExceptions,
     RunTest,
     )
 from testtools.testcase import (
     ErrorHolder,
-    MultipleExceptions,
     PlaceHolder,
     TestCase,
     clone_test_with_new_id,
+    run_test_with,
     skip,
     skipIf,
     skipUnless,
@@ -61,4 +69,4 @@ from testtools.testsuite import (
 # If the releaselevel is 'final', then the tarball will be major.minor.micro.
 # Otherwise it is major.minor.micro~$(revno).
 
-__version__ = (0, 9, 7, 'dev', 0)
+__version__ = (0, 9, 8, 'dev', 0)
diff --git a/lib/testtools/testtools/_spinner.py b/lib/testtools/testtools/_spinner.py
new file mode 100644 (file)
index 0000000..eced554
--- /dev/null
@@ -0,0 +1,317 @@
+# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
+
+"""Evil reactor-spinning logic for running Twisted tests.
+
+This code is highly experimental, liable to change and not to be trusted.  If
+you couldn't write this yourself, you should not be using it.
+"""
+
+__all__ = [
+    'DeferredNotFired',
+    'extract_result',
+    'NoResultError',
+    'not_reentrant',
+    'ReentryError',
+    'Spinner',
+    'StaleJunkError',
+    'TimeoutError',
+    'trap_unhandled_errors',
+    ]
+
+import signal
+
+from testtools.monkey import MonkeyPatcher
+
+from twisted.internet import defer
+from twisted.internet.base import DelayedCall
+from twisted.internet.interfaces import IReactorThreads
+from twisted.python.failure import Failure
+from twisted.python.util import mergeFunctionMetadata
+
+
+class ReentryError(Exception):
+    """Raised when we try to re-enter a function that forbids it."""
+
+    def __init__(self, function):
+        Exception.__init__(self,
+            "%r in not re-entrant but was called within a call to itself."
+            % (function,))
+
+
+def not_reentrant(function, _calls={}):
+    """Decorates a function as not being re-entrant.
+
+    The decorated function will raise an error if called from within itself.
+    """
+    def decorated(*args, **kwargs):
+        if _calls.get(function, False):
+            raise ReentryError(function)
+        _calls[function] = True
+        try:
+            return function(*args, **kwargs)
+        finally:
+            _calls[function] = False
+    return mergeFunctionMetadata(function, decorated)
+
+
+class DeferredNotFired(Exception):
+    """Raised when we extract a result from a Deferred that's not fired yet."""
+
+
+def extract_result(deferred):
+    """Extract the result from a fired deferred.
+
+    It can happen that you have an API that returns Deferreds for
+    compatibility with Twisted code, but is in fact synchronous, i.e. the
+    Deferreds it returns have always fired by the time it returns.  In this
+    case, you can use this function to convert the result back into the usual
+    form for a synchronous API, i.e. the result itself or a raised exception.
+
+    It would be very bad form to use this as some way of checking if a
+    Deferred has fired.
+    """
+    failures = []
+    successes = []
+    deferred.addCallbacks(successes.append, failures.append)
+    if len(failures) == 1:
+        failures[0].raiseException()
+    elif len(successes) == 1:
+        return successes[0]
+    else:
+        raise DeferredNotFired("%r has not fired yet." % (deferred,))
+
+
+def trap_unhandled_errors(function, *args, **kwargs):
+    """Run a function, trapping any unhandled errors in Deferreds.
+
+    Assumes that 'function' will have handled any errors in Deferreds by the
+    time it is complete.  This is almost never true of any Twisted code, since
+    you can never tell when someone has added an errback to a Deferred.
+
+    If 'function' raises, then don't bother doing any unhandled error
+    jiggery-pokery, since something horrible has probably happened anyway.
+
+    :return: A tuple of '(result, error)', where 'result' is the value returned
+        by 'function' and 'error' is a list of `defer.DebugInfo` objects that
+        have unhandled errors in Deferreds.
+    """
+    real_DebugInfo = defer.DebugInfo
+    debug_infos = []
+    def DebugInfo():
+        info = real_DebugInfo()
+        debug_infos.append(info)
+        return info
+    defer.DebugInfo = DebugInfo
+    try:
+        result = function(*args, **kwargs)
+    finally:
+        defer.DebugInfo = real_DebugInfo
+    errors = []
+    for info in debug_infos:
+        if info.failResult is not None:
+            errors.append(info)
+            # Disable the destructor that logs to error. We are already
+            # catching the error here.
+            info.__del__ = lambda: None
+    return result, errors
+
+
+class TimeoutError(Exception):
+    """Raised when run_in_reactor takes too long to run a function."""
+
+    def __init__(self, function, timeout):
+        Exception.__init__(self,
+            "%r took longer than %s seconds" % (function, timeout))
+
+
+class NoResultError(Exception):
+    """Raised when the reactor has stopped but we don't have any result."""
+
+    def __init__(self):
+        Exception.__init__(self,
+            "Tried to get test's result from Deferred when no result is "
+            "available.  Probably means we received SIGINT or similar.")
+
+
+class StaleJunkError(Exception):
+    """Raised when there's junk in the spinner from a previous run."""
+
+    def __init__(self, junk):
+        Exception.__init__(self,
+            "There was junk in the spinner from a previous run. "
+            "Use clear_junk() to clear it out: %r" % (junk,))
+
+
+class Spinner(object):
+    """Spin the reactor until a function is done.
+
+    This class emulates the behaviour of twisted.trial in that it grotesquely
+    and horribly spins the Twisted reactor while a function is running, and
+    then kills the reactor when that function is complete and all the
+    callbacks in its chains are done.
+    """
+
+    _UNSET = object()
+
+    # Signals that we save and restore for each spin.
+    _PRESERVED_SIGNALS = [
+        'SIGINT',
+        'SIGTERM',
+        'SIGCHLD',
+        ]
+
+    # There are many APIs within Twisted itself where a Deferred fires but
+    # leaves cleanup work scheduled for the reactor to do.  Arguably, many of
+    # these are bugs.  As such, we provide a facility to iterate the reactor
+    # event loop a number of times after every call, in order to shake out
+    # these buggy-but-commonplace events.  The default is 0, because that is
+    # the ideal, and it actually works for many cases.
+    _OBLIGATORY_REACTOR_ITERATIONS = 0
+
+    def __init__(self, reactor, debug=False):
+        """Construct a Spinner.
+
+        :param reactor: A Twisted reactor.
+        :param debug: Whether or not to enable Twisted's debugging.  Defaults
+            to False.
+        """
+        self._reactor = reactor
+        self._timeout_call = None
+        self._success = self._UNSET
+        self._failure = self._UNSET
+        self._saved_signals = []
+        self._junk = []
+        self._debug = debug
+
+    def _cancel_timeout(self):
+        if self._timeout_call:
+            self._timeout_call.cancel()
+
+    def _get_result(self):
+        if self._failure is not self._UNSET:
+            self._failure.raiseException()
+        if self._success is not self._UNSET:
+            return self._success
+        raise NoResultError()
+
+    def _got_failure(self, result):
+        self._cancel_timeout()
+        self._failure = result
+
+    def _got_success(self, result):
+        self._cancel_timeout()
+        self._success = result
+
+    def _stop_reactor(self, ignored=None):
+        """Stop the reactor!"""
+        self._reactor.crash()
+
+    def _timed_out(self, function, timeout):
+        e = TimeoutError(function, timeout)
+        self._failure = Failure(e)
+        self._stop_reactor()
+
+    def _clean(self):
+        """Clean up any junk in the reactor.
+
+        Will always iterate the reactor a number of times equal to
+        `_OBLIGATORY_REACTOR_ITERATIONS`.  This is to work around bugs in
+        various Twisted APIs where a Deferred fires but still leaves work
+        (e.g. cancelling a call, actually closing a connection) for the
+        reactor to do.
+        """
+        for i in range(self._OBLIGATORY_REACTOR_ITERATIONS):
+            self._reactor.iterate(0)
+        junk = []
+        for delayed_call in self._reactor.getDelayedCalls():
+            delayed_call.cancel()
+            junk.append(delayed_call)
+        for selectable in self._reactor.removeAll():
+            # Twisted sends a 'KILL' signal to selectables that provide
+            # IProcessTransport.  Since only _dumbwin32proc processes do this,
+            # we aren't going to bother.
+            junk.append(selectable)
+        if IReactorThreads.providedBy(self._reactor):
+            self._reactor.suggestThreadPoolSize(0)
+            if self._reactor.threadpool is not None:
+                self._reactor._stopThreadPool()
+        self._junk.extend(junk)
+        return junk
+
+    def clear_junk(self):
+        """Clear out our recorded junk.
+
+        :return: Whatever junk was there before.
+        """
+        junk = self._junk
+        self._junk = []
+        return junk
+
+    def get_junk(self):
+        """Return any junk that has been found on the reactor."""
+        return self._junk
+
+    def _save_signals(self):
+        available_signals = [
+            getattr(signal, name, None) for name in self._PRESERVED_SIGNALS]
+        self._saved_signals = [
+            (sig, signal.getsignal(sig)) for sig in available_signals if sig]
+
+    def _restore_signals(self):
+        for sig, hdlr in self._saved_signals:
+            signal.signal(sig, hdlr)
+        self._saved_signals = []
+
+    @not_reentrant
+    def run(self, timeout, function, *args, **kwargs):
+        """Run 'function' in a reactor.
+
+        If 'function' returns a Deferred, the reactor will keep spinning until
+        the Deferred fires and its chain completes or until the timeout is
+        reached -- whichever comes first.
+
+        :raise TimeoutError: If 'timeout' is reached before the `Deferred`
+            returned by 'function' has completed its callback chain.
+        :raise NoResultError: If the reactor is somehow interrupted before
+            the `Deferred` returned by 'function' has completed its callback
+            chain.
+        :raise StaleJunkError: If there's junk in the spinner from a previous
+            run.
+        :return: Whatever is at the end of the function's callback chain.  If
+            it's an error, then raise that.
+        """
+        debug = MonkeyPatcher()
+        if self._debug:
+            debug.add_patch(defer.Deferred, 'debug', True)
+            debug.add_patch(DelayedCall, 'debug', True)
+        debug.patch()
+        try:
+            junk = self.get_junk()
+            if junk:
+                raise StaleJunkError(junk)
+            self._save_signals()
+            self._timeout_call = self._reactor.callLater(
+                timeout, self._timed_out, function, timeout)
+            # Calling 'stop' on the reactor will make it impossible to
+            # re-start the reactor.  Since the default signal handlers for
+            # TERM, BREAK and INT all call reactor.stop(), we'll patch it over
+            # with crash.  XXX: It might be a better idea to either install
+            # custom signal handlers or to override the methods that are
+            # Twisted's signal handlers.
+            stop, self._reactor.stop = self._reactor.stop, self._reactor.crash
+            def run_function():
+                d = defer.maybeDeferred(function, *args, **kwargs)
+                d.addCallbacks(self._got_success, self._got_failure)
+                d.addBoth(self._stop_reactor)
+            try:
+                self._reactor.callWhenRunning(run_function)
+                self._reactor.run()
+            finally:
+                self._reactor.stop = stop
+                self._restore_signals()
+            try:
+                return self._get_result()
+            finally:
+                self._clean()
+        finally:
+            debug.restore()
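A minimal sketch of how the Spinner API above can be driven, assuming Twisted
is installed (the helper function is hypothetical)::

    from twisted.internet import defer, reactor
    from testtools._spinner import Spinner

    def delayed_value():
        # Fires 'done' a moment after the reactor starts spinning.
        d = defer.Deferred()
        reactor.callLater(0.1, d.callback, 'done')
        return d

    spinner = Spinner(reactor)
    # Spins the reactor until the Deferred fires; raises TimeoutError after 5s.
    result = spinner.run(5.0, delayed_value)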
diff --git a/lib/testtools/testtools/compat.py b/lib/testtools/testtools/compat.py
index 0dd2fe8bf9ecf8049db79a48772bdd8aa32d964b..1f0b8cfe8549a7418bb4fb1f71b53986620c4c68 100644 (file)
@@ -209,38 +209,43 @@ def _format_exc_info(eclass, evalue, tb, limit=None):
         list = []
     if evalue is None:
         # Is a (deprecated) string exception
-        list.append(eclass.decode("ascii", "replace"))
-    elif isinstance(evalue, SyntaxError) and len(evalue.args) > 1:
+        list.append((eclass + "\n").decode("ascii", "replace"))
+        return list
+    if isinstance(evalue, SyntaxError):
         # Avoid duplicating the special formatting for SyntaxError here,
         # instead create a new instance with unicode filename and line
         # Potentially gives duff spacing, but that's a pre-existing issue
-        filename, lineno, offset, line = evalue.args[1]
-        if line:
+        try:
+            msg, (filename, lineno, offset, line) = evalue
+        except (TypeError, ValueError):
+            pass # Strange exception instance, fall through to generic code
+        else:
             # Errors during parsing give the line from buffer encoded as
             # latin-1 or utf-8 or the encoding of the file depending on the
             # coding and whether the patch for issue #1031213 is applied, so
             # give up on trying to decode it and just read the file again
-            bytestr = linecache.getline(filename, lineno)
-            if bytestr:
-                if lineno == 1 and bytestr.startswith("\xef\xbb\xbf"):
-                    bytestr = bytestr[3:]
-                line = bytestr.decode(_get_source_encoding(filename), "replace")
-                del linecache.cache[filename]
-            else:
-                line = line.decode("ascii", "replace")
-        if filename:
-            filename = filename.decode(fs_enc, "replace")
-        evalue = eclass(evalue.args[0], (filename, lineno, offset, line))
-        list.extend(traceback.format_exception_only(eclass, evalue))
+            if line:
+                bytestr = linecache.getline(filename, lineno)
+                if bytestr:
+                    if lineno == 1 and bytestr.startswith("\xef\xbb\xbf"):
+                        bytestr = bytestr[3:]
+                    line = bytestr.decode(
+                        _get_source_encoding(filename), "replace")
+                    del linecache.cache[filename]
+                else:
+                    line = line.decode("ascii", "replace")
+            if filename:
+                filename = filename.decode(fs_enc, "replace")
+            evalue = eclass(msg, (filename, lineno, offset, line))
+            list.extend(traceback.format_exception_only(eclass, evalue))
+            return list
+    sclass = eclass.__name__
+    svalue = _exception_to_text(evalue)
+    if svalue:
+        list.append("%s: %s\n" % (sclass, svalue))
+    elif svalue is None:
+        # GZ 2010-05-24: Not a great fallback message, but keep for the moment
+        list.append("%s: <unprintable %s object>\n" % (sclass, sclass))
     else:
-        sclass = eclass.__name__
-        svalue = _exception_to_text(evalue)
-        if svalue:
-            list.append("%s: %s\n" % (sclass, svalue))
-        elif svalue is None:
-            # GZ 2010-05-24: Not a great fallback message, but keep for the
-            #                the same for compatibility for the moment
-            list.append("%s: <unprintable %s object>\n" % (sclass, sclass))
-        else:
-            list.append("%s\n" % sclass)
+        list.append("%s\n" % sclass)
     return list
diff --git a/lib/testtools/testtools/content.py b/lib/testtools/testtools/content.py
index 843133aa1aa86801add19bac1d4f8aede11a190c..86df09fc6e4fd2cd21a2ba8166270331a0cd2511 100644 (file)
@@ -5,10 +5,13 @@
 import codecs
 
 from testtools.compat import _b
-from testtools.content_type import ContentType
+from testtools.content_type import ContentType, UTF8_TEXT
 from testtools.testresult import TestResult
 
 
+_join_b = _b("").join
+
+
 class Content(object):
     """A MIME-like Content object.
 
@@ -31,7 +34,7 @@ class Content(object):
 
     def __eq__(self, other):
         return (self.content_type == other.content_type and
-            ''.join(self.iter_bytes()) == ''.join(other.iter_bytes()))
+            _join_b(self.iter_bytes()) == _join_b(other.iter_bytes()))
 
     def iter_bytes(self):
         """Iterate over bytestrings of the serialised content."""
@@ -68,7 +71,7 @@ class Content(object):
 
     def __repr__(self):
         return "<Content type=%r, value=%r>" % (
-            self.content_type, ''.join(self.iter_bytes()))
+            self.content_type, _join_b(self.iter_bytes()))
 
 
 class TracebackContent(Content):
@@ -89,3 +92,11 @@ class TracebackContent(Content):
         value = self._result._exc_info_to_unicode(err, test)
         super(TracebackContent, self).__init__(
             content_type, lambda: [value.encode("utf8")])
+
+
+def text_content(text):
+    """Create a `Content` object from some text.
+
+    This is useful for adding details which are short strings.
+    """
+    return Content(UTF8_TEXT, lambda: [text.encode('utf8')])
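A minimal sketch of the new ``text_content`` helper together with
``TestCase.addDetail`` (the test class is hypothetical)::

    from testtools import TestCase
    from testtools.content import text_content

    class DetailExample(TestCase):

        def test_attaches_note(self):
            # Attach a short human-readable string as a detail of this test.
            self.addDetail('note', text_content('short, human-readable text'))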
diff --git a/lib/testtools/testtools/deferredruntest.py b/lib/testtools/testtools/deferredruntest.py
new file mode 100644 (file)
index 0000000..50153be
--- /dev/null
@@ -0,0 +1,336 @@
+# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
+
+"""Individual test case execution for tests that return Deferreds.
+
+This module is highly experimental and is liable to change in ways that cause
+subtle failures in tests.  Use at your own peril.
+"""
+
+__all__ = [
+    'assert_fails_with',
+    'AsynchronousDeferredRunTest',
+    'AsynchronousDeferredRunTestForBrokenTwisted',
+    'SynchronousDeferredRunTest',
+    ]
+
+import sys
+
+from testtools import try_imports
+from testtools.content import (
+    Content,
+    text_content,
+    )
+from testtools.content_type import UTF8_TEXT
+from testtools.runtest import RunTest
+from testtools._spinner import (
+    extract_result,
+    NoResultError,
+    Spinner,
+    TimeoutError,
+    trap_unhandled_errors,
+    )
+
+from twisted.internet import defer
+from twisted.python import log
+from twisted.trial.unittest import _LogObserver
+
+StringIO = try_imports(['StringIO.StringIO', 'io.StringIO'])
+
+
+class _DeferredRunTest(RunTest):
+    """Base for tests that return Deferreds."""
+
+    def _got_user_failure(self, failure, tb_label='traceback'):
+        """We got a failure from user code."""
+        return self._got_user_exception(
+            (failure.type, failure.value, failure.getTracebackObject()),
+            tb_label=tb_label)
+
+
+class SynchronousDeferredRunTest(_DeferredRunTest):
+    """Runner for tests that return synchronous Deferreds."""
+
+    def _run_user(self, function, *args):
+        d = defer.maybeDeferred(function, *args)
+        d.addErrback(self._got_user_failure)
+        result = extract_result(d)
+        return result
+
+
+def run_with_log_observers(observers, function, *args, **kwargs):
+    """Run 'function' with the given Twisted log observers."""
+    real_observers = log.theLogPublisher.observers
+    for observer in real_observers:
+        log.theLogPublisher.removeObserver(observer)
+    for observer in observers:
+        log.theLogPublisher.addObserver(observer)
+    try:
+        return function(*args, **kwargs)
+    finally:
+        for observer in observers:
+            log.theLogPublisher.removeObserver(observer)
+        for observer in real_observers:
+            log.theLogPublisher.addObserver(observer)
+
+
+# Observer of the Twisted log that we install during tests.
+_log_observer = _LogObserver()
+
+
+
+class AsynchronousDeferredRunTest(_DeferredRunTest):
+    """Runner for tests that return Deferreds that fire asynchronously.
+
+    That is, this test runner assumes that the Deferreds will only fire if the
+    reactor is left to spin for a while.
+
+    Do not rely too heavily on the nuances of the behaviour of this class.
+    What it does to the reactor is black magic, and if we can find nicer ways
+    of doing it we will gladly break backwards compatibility.
+
+    This is highly experimental code.  Use at your own risk.
+    """
+
+    def __init__(self, case, handlers=None, reactor=None, timeout=0.005,
+                 debug=False):
+        """Construct an `AsynchronousDeferredRunTest`.
+
+        :param case: The `testtools.TestCase` to run.
+        :param handlers: A list of exception handlers (ExceptionType, handler)
+            where 'handler' is a callable that takes a `TestCase`, a
+            `TestResult` and the exception raised.
+        :param reactor: The Twisted reactor to use.  If not given, we use the
+            default reactor.
+        :param timeout: The maximum time allowed for running a test.  The
+            default is 0.005s.
+        :param debug: Whether or not to enable Twisted's debugging.  Use this
+            to get information about unhandled Deferreds and left-over
+            DelayedCalls.  Defaults to False.
+        """
+        super(AsynchronousDeferredRunTest, self).__init__(case, handlers)
+        if reactor is None:
+            from twisted.internet import reactor
+        self._reactor = reactor
+        self._timeout = timeout
+        self._debug = debug
+
+    @classmethod
+    def make_factory(cls, reactor=None, timeout=0.005, debug=False):
+        """Make a factory that conforms to the RunTest factory interface."""
+        # This is horrible, but it means that the return value of the method
+        # will be able to be assigned to a class variable *and* also be
+        # invoked directly.
+        class AsynchronousDeferredRunTestFactory:
+            def __call__(self, case, handlers=None):
+                return cls(case, handlers, reactor, timeout, debug)
+        return AsynchronousDeferredRunTestFactory()
+
+    @defer.deferredGenerator
+    def _run_cleanups(self):
+        """Run the cleanups on the test case.
+
+        We expect that the cleanups on the test case can also return
+        asynchronous Deferreds.  As such, we take the responsibility for
+        running the cleanups, rather than letting TestCase do it.
+        """
+        while self.case._cleanups:
+            f, args, kwargs = self.case._cleanups.pop()
+            d = defer.maybeDeferred(f, *args, **kwargs)
+            thing = defer.waitForDeferred(d)
+            yield thing
+            try:
+                thing.getResult()
+            except Exception:
+                exc_info = sys.exc_info()
+                self.case._report_traceback(exc_info)
+                last_exception = exc_info[1]
+        yield last_exception
+
+    def _make_spinner(self):
+        """Make the `Spinner` to be used to run the tests."""
+        return Spinner(self._reactor, debug=self._debug)
+
+    def _run_deferred(self):
+        """Run the test, assuming everything in it is Deferred-returning.
+
+        This should return a Deferred that fires with True if the test was
+        successful and False if the test was not successful.  It should *not*
+        call addSuccess on the result, because there's reactor clean up that
+        needs to be done afterwards.
+        """
+        fails = []
+
+        def fail_if_exception_caught(exception_caught):
+            if self.exception_caught == exception_caught:
+                fails.append(None)
+
+        def clean_up(ignored=None):
+            """Run the cleanups."""
+            d = self._run_cleanups()
+            def clean_up_done(result):
+                if result is not None:
+                    self._exceptions.append(result)
+                    fails.append(None)
+            return d.addCallback(clean_up_done)
+
+        def set_up_done(exception_caught):
+            """Set up is done, either clean up or run the test."""
+            if self.exception_caught == exception_caught:
+                fails.append(None)
+                return clean_up()
+            else:
+                d = self._run_user(self.case._run_test_method, self.result)
+                d.addCallback(fail_if_exception_caught)
+                d.addBoth(tear_down)
+                return d
+
+        def tear_down(ignored):
+            d = self._run_user(self.case._run_teardown, self.result)
+            d.addCallback(fail_if_exception_caught)
+            d.addBoth(clean_up)
+            return d
+
+        d = self._run_user(self.case._run_setup, self.result)
+        d.addCallback(set_up_done)
+        d.addBoth(lambda ignored: len(fails) == 0)
+        return d
+
+    def _log_user_exception(self, e):
+        """Raise 'e' and report it as a user exception."""
+        try:
+            raise e
+        except e.__class__:
+            self._got_user_exception(sys.exc_info())
+
+    def _blocking_run_deferred(self, spinner):
+        try:
+            return trap_unhandled_errors(
+                spinner.run, self._timeout, self._run_deferred)
+        except NoResultError:
+            # We didn't get a result at all!  This could be for any number of
+            # reasons, but most likely someone hit Ctrl-C during the test.
+            raise KeyboardInterrupt
+        except TimeoutError:
+            # The function took too long to run.
+            self._log_user_exception(TimeoutError(self.case, self._timeout))
+            return False, []
+
+    def _run_core(self):
+        # Add an observer to trap all logged errors.
+        error_observer = _log_observer
+        full_log = StringIO()
+        full_observer = log.FileLogObserver(full_log)
+        spinner = self._make_spinner()
+        successful, unhandled = run_with_log_observers(
+            [error_observer.gotEvent, full_observer.emit],
+            self._blocking_run_deferred, spinner)
+
+        self.case.addDetail(
+            'twisted-log', Content(UTF8_TEXT, full_log.readlines))
+
+        logged_errors = error_observer.flushErrors()
+        for logged_error in logged_errors:
+            successful = False
+            self._got_user_failure(logged_error, tb_label='logged-error')
+
+        if unhandled:
+            successful = False
+            for debug_info in unhandled:
+                f = debug_info.failResult
+                info = debug_info._getDebugTracebacks()
+                if info:
+                    self.case.addDetail(
+                        'unhandled-error-in-deferred-debug',
+                        text_content(info))
+                self._got_user_failure(f, 'unhandled-error-in-deferred')
+
+        junk = spinner.clear_junk()
+        if junk:
+            successful = False
+            self._log_user_exception(UncleanReactorError(junk))
+
+        if successful:
+            self.result.addSuccess(self.case, details=self.case.getDetails())
+
+    def _run_user(self, function, *args):
+        """Run a user-supplied function.
+
+        This just makes sure that it returns a Deferred, regardless of how the
+        user wrote it.
+        """
+        d = defer.maybeDeferred(function, *args)
+        return d.addErrback(self._got_user_failure)
+
+
+class AsynchronousDeferredRunTestForBrokenTwisted(AsynchronousDeferredRunTest):
+    """Test runner that works around Twisted brokenness re reactor junk.
+
+    There are many APIs within Twisted itself where a Deferred fires but
+    leaves cleanup work scheduled for the reactor to do.  Arguably, many of
+    these are bugs.  This runner iterates the reactor event loop a number of
+    times after every test, in order to shake out these buggy-but-commonplace
+    events.
+    """
+
+    def _make_spinner(self):
+        spinner = super(
+            AsynchronousDeferredRunTestForBrokenTwisted, self)._make_spinner()
+        spinner._OBLIGATORY_REACTOR_ITERATIONS = 2
+        return spinner
+
+
+def assert_fails_with(d, *exc_types, **kwargs):
+    """Assert that 'd' will fail with one of 'exc_types'.
+
+    The normal way to use this is to return the result of 'assert_fails_with'
+    from your unit test.
+
+    Note that this function is experimental and unstable.  Use at your own
+    peril; expect the API to change.
+
+    :param d: A Deferred that is expected to fail.
+    :param *exc_types: The exception types that the Deferred is expected to
+        fail with.
+    :param failureException: An optional keyword argument.  If provided, will
+        raise that exception instead of `testtools.TestCase.failureException`.
+    :return: A Deferred that will fail with an `AssertionError` if 'd' does
+        not fail with one of the exception types.
+    """
+    failureException = kwargs.pop('failureException', None)
+    if failureException is None:
+        # Avoid circular imports.
+        from testtools import TestCase
+        failureException = TestCase.failureException
+    expected_names = ", ".join(exc_type.__name__ for exc_type in exc_types)
+    def got_success(result):
+        raise failureException(
+            "%s not raised (%r returned)" % (expected_names, result))
+    def got_failure(failure):
+        if failure.check(*exc_types):
+            return failure.value
+        raise failureException("%s raised instead of %s:\n %s" % (
+            failure.type.__name__, expected_names, failure.getTraceback()))
+    return d.addCallbacks(got_success, got_failure)
+
+
+def flush_logged_errors(*error_types):
+    return _log_observer.flushErrors(*error_types)
+
+
+class UncleanReactorError(Exception):
+    """Raised when the reactor has junk in it."""
+
+    def __init__(self, junk):
+        Exception.__init__(self,
+            "The reactor still thinks it needs to do things. Close all "
+            "connections, kill all processes and make sure all delayed "
+            "calls have either fired or been cancelled:\n%s"
+            % ''.join(map(self._get_junk_info, junk)))
+
+    def _get_junk_info(self, junk):
+        from twisted.internet.base import DelayedCall
+        if isinstance(junk, DelayedCall):
+            ret = str(junk)
+        else:
+            ret = repr(junk)
+        return '  %s\n' % (ret,)
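A minimal sketch of using this module's runner with ``assert_fails_with``,
assuming Twisted is installed (the failing operation is hypothetical)::

    from twisted.internet import defer
    from testtools import TestCase
    from testtools.deferredruntest import (
        AsynchronousDeferredRunTest,
        assert_fails_with,
        )

    class FailureTest(TestCase):

        # make_factory lets us pass a more generous timeout than the default.
        run_tests_with = AsynchronousDeferredRunTest.make_factory(timeout=2.0)

        def test_operation_fails(self):
            d = defer.fail(RuntimeError('boom'))
            # Passes only if the Deferred fails with RuntimeError.
            return assert_fails_with(d, RuntimeError)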
diff --git a/lib/testtools/testtools/helpers.py b/lib/testtools/testtools/helpers.py
new file mode 100644 (file)
index 0000000..0f489c7
--- /dev/null
@@ -0,0 +1,64 @@
+# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
+
+__all__ = [
+    'try_import',
+    'try_imports',
+    ]
+
+
+def try_import(name, alternative=None):
+    """Attempt to import `name`.  If it fails, return `alternative`.
+
+    When supporting multiple versions of Python or optional dependencies, it
+    is useful to be able to try to import a module.
+
+    :param name: The name of the object to import, e.g. 'os.path' or
+        'os.path.join'.
+    :param alternative: The value to return if no module can be imported.
+        Defaults to None.
+    """
+    module_segments = name.split('.')
+    while module_segments:
+        module_name = '.'.join(module_segments)
+        try:
+            module = __import__(module_name)
+        except ImportError:
+            module_segments.pop()
+            continue
+        else:
+            break
+    else:
+        return alternative
+    nonexistent = object()
+    for segment in name.split('.')[1:]:
+        module = getattr(module, segment, nonexistent)
+        if module is nonexistent:
+            return alternative
+    return module
+
+
+_RAISE_EXCEPTION = object()
+def try_imports(module_names, alternative=_RAISE_EXCEPTION):
+    """Attempt to import modules.
+
+    Tries to import the first module in `module_names`.  If it can be
+    imported, we return it.  If not, we go on to the second module and try
+    that.  The process continues until we run out of modules to try.  If none
+    of the modules can be imported, either raise an exception or return the
+    provided `alternative` value.
+
+    :param module_names: A sequence of module names to try to import.
+    :param alternative: The value to return if no module can be imported.
+        If unspecified, we raise an ImportError.
+    :raises ImportError: If none of the modules can be imported and no
+        alternative value was specified.
+    """
+    module_names = list(module_names)
+    for module_name in module_names:
+        module = try_import(module_name)
+        if module:
+            return module
+    if alternative is _RAISE_EXCEPTION:
+        raise ImportError(
+            "Could not import any of: %s" % ', '.join(module_names))
+    return alternative
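Beyond the StringIO example in the MANUAL, a small sketch of these helpers
(module names here are just common stdlib examples)::

    from testtools.helpers import try_import, try_imports

    # Returns the json module where available (Python 2.6+), else None.
    json = try_import('json')

    # Prefer the C-accelerated pickle, falling back to the pure-Python module.
    pickle = try_imports(['cPickle', 'pickle'])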
diff --git a/lib/testtools/testtools/matchers.py b/lib/testtools/testtools/matchers.py
index 61b5bd74f95381bd6e7da2809ba77484067b787c..50cc50d31df9ad62578547c3fcb1fd1a1befd791 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2009 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
 
 """Matchers, a way to express complex assertions outside the testcase.
 
@@ -19,13 +19,18 @@ __all__ = [
     'LessThan',
     'MatchesAll',
     'MatchesAny',
+    'MatchesException',
     'NotEquals',
     'Not',
+    'Raises',
+    'raises',
+    'StartsWith',
     ]
 
 import doctest
 import operator
 from pprint import pformat
+import sys
 
 
 class Matcher(object):
@@ -101,6 +106,10 @@ class Mismatch(object):
         """
         return getattr(self, '_details', {})
 
+    def __repr__(self):
+        return  "<testtools.matchers.Mismatch object at %x attributes=%r>" % (
+            id(self), self.__dict__)
+
 
 class DocTestMatches(object):
     """See if a string matches a doctest example."""
@@ -152,6 +161,39 @@ class DocTestMismatch(Mismatch):
         return self.matcher._describe_difference(self.with_nl)
 
 
+class DoesNotStartWith(Mismatch):
+
+    def __init__(self, matchee, expected):
+        """Create a DoesNotStartWith Mismatch.
+
+        :param matchee: the string that did not match.
+        :param expected: the string that `matchee` was expected to start
+            with.
+        """
+        self.matchee = matchee
+        self.expected = expected
+
+    def describe(self):
+        return "'%s' does not start with '%s'." % (
+            self.matchee, self.expected)
+
+
+class DoesNotEndWith(Mismatch):
+
+    def __init__(self, matchee, expected):
+        """Create a DoesNotEndWith Mismatch.
+
+        :param matchee: the string that did not match.
+        :param expected: the string that `matchee` was expected to end with.
+        """
+        self.matchee = matchee
+        self.expected = expected
+
+    def describe(self):
+        return "'%s' does not end with '%s'." % (
+            self.matchee, self.expected)
+
+
 class _BinaryComparison(object):
     """Matcher that compares an object to another object."""
 
@@ -187,7 +229,6 @@ class _BinaryMismatch(Mismatch):
                 pformat(self.other))
         else:
             return "%s %s %s" % (left, self._mismatch_string,right)
-        return "%r %s %r" % (self.expected, self._mismatch_string, self.other)
 
 
 class Equals(_BinaryComparison):
@@ -305,6 +346,106 @@ class MatchedUnexpectedly(Mismatch):
         return "%r matches %s" % (self.other, self.matcher)
 
 
+class MatchesException(Matcher):
+    """Match an exc_info tuple against an exception instance or type."""
+
+    def __init__(self, exception):
+        """Create a MatchesException that will match exc_info's for exception.
+
+        :param exception: Either an exception instance or type.
+            If an instance is given, the type and arguments of the exception
+            are checked. If a type is given only the type of the exception is
+            checked.
+        """
+        Matcher.__init__(self)
+        self.expected = exception
+
+    def _expected_type(self):
+        if type(self.expected) is type:
+            return self.expected
+        return type(self.expected)
+
+    def match(self, other):
+        if type(other) != tuple:
+            return Mismatch('%r is not an exc_info tuple' % other)
+        if not issubclass(other[0], self._expected_type()):
+            return Mismatch('%r is not a %r' % (
+                other[0], self._expected_type()))
+        if (type(self.expected) is not type and
+            other[1].args != self.expected.args):
+            return Mismatch('%r has different arguments to %r.' % (
+                other[1], self.expected))
+
+    def __str__(self):
+        return "MatchesException(%r)" % self.expected
+
+
+class StartsWith(Matcher):
+    """Checks whether one string starts with another."""
+
+    def __init__(self, expected):
+        """Create a StartsWith Matcher.
+
+        :param expected: the string that matchees should start with.
+        """
+        self.expected = expected
+
+    def __str__(self):
+        return "Starts with '%s'." % self.expected
+
+    def match(self, matchee):
+        if not matchee.startswith(self.expected):
+            return DoesNotStartWith(matchee, self.expected)
+        return None
+
+
+class EndsWith(Matcher):
+    """Checks whether one string starts with another."""
+
+    def __init__(self, expected):
+        """Create a EndsWith Matcher.
+
+        :param expected: the string that matchees should end with.
+        """
+        self.expected = expected
+
+    def __str__(self):
+        return "Ends with '%s'." % self.expected
+
+    def match(self, matchee):
+        if not matchee.endswith(self.expected):
+            return DoesNotEndWith(matchee, self.expected)
+        return None
+
+
+class KeysEqual(Matcher):
+    """Checks whether a dict has particular keys."""
+
+    def __init__(self, *expected):
+        """Create a `KeysEqual` Matcher.
+
+        :param *expected: The keys the dict is expected to have.  If a dict
+            is given, we use the keys of that dict; if any other collection
+            is given, we assume it is a collection of expected keys.
+        """
+        try:
+            self.expected = expected.keys()
+        except AttributeError:
+            self.expected = list(expected)
+
+    def __str__(self):
+        return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))
+
+    def match(self, matchee):
+        expected = sorted(self.expected)
+        matched = Equals(expected).match(sorted(matchee.keys()))
+        if matched:
+            return AnnotatedMismatch(
+                'Keys not equal',
+                _BinaryMismatch(expected, 'does not match', matchee))
+        return None
+
+
 class Annotate(object):
     """Annotates a matcher with a descriptive string.
 
@@ -333,3 +474,57 @@ class AnnotatedMismatch(Mismatch):
 
     def describe(self):
         return '%s: %s' % (self.mismatch.describe(), self.annotation)
+
+
+class Raises(Matcher):
+    """Match if the matchee raises an exception when called.
+
+    Exceptions that are not subclasses of Exception propagate out of the
+    Raises.match call unless they are explicitly matched.
+    """
+
+    def __init__(self, exception_matcher=None):
+        """Create a Raises matcher.
+
+        :param exception_matcher: Optional validator for the exception raised
+            by matchee. If supplied the exc_info tuple for the exception raised
+            is passed into that matcher. If no exception_matcher is supplied
+            then the simple fact of raising an exception is considered enough
+            to match on.
+        """
+        self.exception_matcher = exception_matcher
+
+    def match(self, matchee):
+        try:
+            result = matchee()
+            return Mismatch('%r returned %r' % (matchee, result))
+        # Catch all exceptions: Raises() should be able to match a
+        # KeyboardInterrupt or SystemExit.
+        except:
+            exc_info = sys.exc_info()
+            if self.exception_matcher:
+                mismatch = self.exception_matcher.match(sys.exc_info())
+                if not mismatch:
+                    return
+            else:
+                mismatch = None
+            # The exception did not match, or no explicit matching logic was
+            # performed. If the exception is a non-user exception (that is, not
+            # a subclass of Exception), then propagate it.
+            if not issubclass(exc_info[0], Exception):
+                raise exc_info[0], exc_info[1], exc_info[2]
+            return mismatch
+
+    def __str__(self):
+        return 'Raises()'
+
+
+def raises(exception):
+    """Make a matcher that checks that a callable raises an exception.
+
+    This is a convenience function, exactly equivalent to::
+        return Raises(MatchesException(exception))
+
+    See `Raises` and `MatchesException` for more information.
+    """
+    return Raises(MatchesException(exception))
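
For orientation, a minimal sketch of the matchers added above in use with
assertThat (as the test changes later in this patch do); the test class and
values are illustrative only, not part of the patch::

    from testtools import TestCase
    from testtools.matchers import (
        EndsWith, KeysEqual, MatchesException, Raises, StartsWith, raises)

    class NewMatcherExamples(TestCase):

        def test_examples(self):
            # String prefix and suffix checks.
            self.assertThat('abcdef', StartsWith('abc'))
            self.assertThat('abcdef', EndsWith('def'))
            # Compare the keys of a dict against an expected set of keys.
            self.assertThat({'a': 1, 'b': 2}, KeysEqual('a', 'b'))
            # Match the exception raised by a callable.
            self.assertThat(lambda: {}['missing'],
                Raises(MatchesException(KeyError)))
            # The raises() helper is shorthand for the line above.
            self.assertThat(lambda: {}['missing'], raises(KeyError))
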
index b6f2c491cd5abe3b66a63d92bb281e4e96370995..da4496a0c08b22dd9a8c3b5d0871b2980a895a58 100755 (executable)
@@ -14,6 +14,7 @@ import sys
 
 from testtools import TextTestResult
 from testtools.compat import classtypes, istext, unicode_output_stream
+from testtools.testsuite import iterate_tests
 
 
 defaultTestLoader = unittest.defaultTestLoader
@@ -34,9 +35,12 @@ else:
 class TestToolsTestRunner(object):
     """ A thunk object to support unittest.TestProgram."""
 
+    def __init__(self, stdout):
+        self.stdout = stdout
+
     def run(self, test):
         "Run the given test case or test suite."
-        result = TextTestResult(unicode_output_stream(sys.stdout))
+        result = TextTestResult(unicode_output_stream(self.stdout))
         result.startTestRun()
         try:
             return test.run(result)
@@ -58,6 +62,12 @@ class TestToolsTestRunner(object):
 #    removed.
 #  - A tweak has been added to detect 'python -m *.run' and use a
 #    better progName in that case.
+#  - self.module is more comprehensively set to None when being invoked from
+#    the commandline - __name__ is used as a sentinel value.
+#  - --list has been added which can list tests (should be upstreamed).
+#  - --load-list has been added which can reduce the tests used (should be
+#    upstreamed).
+#  - The limitation of using getopt is declared to the user.
 
 FAILFAST     = "  -f, --failfast   Stop on first failure\n"
 CATCHBREAK   = "  -c, --catch      Catch control-C and display results\n"
@@ -70,14 +80,17 @@ Options:
   -h, --help       Show this message
   -v, --verbose    Verbose output
   -q, --quiet      Minimal output
+  -l, --list       List tests rather than executing them.
+  --load-list      Specifies a file containing test ids; only tests matching
+                   those ids are executed.
 %(failfast)s%(catchbreak)s%(buffer)s
 Examples:
   %(progName)s test_module               - run tests from test_module
   %(progName)s module.TestClass          - run tests from module.TestClass
   %(progName)s module.Class.test_method  - run specified test method
 
-[tests] can be a list of any number of test modules, classes and test
-methods.
+All options must come before [tests].  [tests] can be a list of any number of
+test modules, classes and test methods.
 
 Alternative Usage: %(progName)s discover [options]
 
@@ -87,6 +100,9 @@ Options:
   -p pattern       Pattern to match test files ('test*.py' default)
   -t directory     Top level directory of project (default to
                    start directory)
+  -l, --list       List tests rather than executing them.
+  --load-list      Specifies a file containing test ids; only tests matching
+                   those ids are executed.
 
 For test discovery all test modules must be importable from the top
 level directory of the project.
@@ -102,11 +118,13 @@ class TestProgram(object):
     # defaults for testing
     failfast = catchbreak = buffer = progName = None
 
-    def __init__(self, module='__main__', defaultTest=None, argv=None,
+    def __init__(self, module=__name__, defaultTest=None, argv=None,
                     testRunner=None, testLoader=defaultTestLoader,
                     exit=True, verbosity=1, failfast=None, catchbreak=None,
-                    buffer=None):
-        if istext(module):
+                    buffer=None, stdout=None):
+        if module == __name__:
+            self.module = None
+        elif istext(module):
             self.module = __import__(module)
             for part in module.split('.')[1:]:
                 self.module = getattr(self.module, part)
@@ -121,6 +139,8 @@ class TestProgram(object):
         self.verbosity = verbosity
         self.buffer = buffer
         self.defaultTest = defaultTest
+        self.listtests = False
+        self.load_list = None
         self.testRunner = testRunner
         self.testLoader = testLoader
         progName = argv[0]
@@ -131,7 +151,27 @@ class TestProgram(object):
             progName = os.path.basename(argv[0])
         self.progName = progName
         self.parseArgs(argv)
-        self.runTests()
+        if self.load_list:
+            # TODO: preserve existing suites (like testresources does in
+            # OptimisingTestSuite.add, but with a standard protocol).
+            # This is needed because the load_tests hook allows arbitrary
+            # suites, even if that is rarely used.
+            source = file(self.load_list, 'rb')
+            try:
+                lines = source.readlines()
+            finally:
+                source.close()
+            test_ids = set(line.strip() for line in lines)
+            filtered = unittest.TestSuite()
+            for test in iterate_tests(self.test):
+                if test.id() in test_ids:
+                    filtered.addTest(test)
+            self.test = filtered
+        if not self.listtests:
+            self.runTests()
+        else:
+            for test in iterate_tests(self.test):
+                stdout.write('%s\n' % test.id())
 
     def usageExit(self, msg=None):
         if msg:
@@ -153,9 +193,10 @@ class TestProgram(object):
             return
 
         import getopt
-        long_opts = ['help', 'verbose', 'quiet', 'failfast', 'catch', 'buffer']
+        long_opts = ['help', 'verbose', 'quiet', 'failfast', 'catch', 'buffer',
+            'list', 'load-list=']
         try:
-            options, args = getopt.getopt(argv[1:], 'hHvqfcb', long_opts)
+            options, args = getopt.getopt(argv[1:], 'hHvqfcbl', long_opts)
             for opt, value in options:
                 if opt in ('-h','-H','--help'):
                     self.usageExit()
@@ -175,21 +216,20 @@ class TestProgram(object):
                     if self.buffer is None:
                         self.buffer = True
                     # Should this raise an exception if -b is not valid?
+                if opt in ('-l', '--list'):
+                    self.listtests = True
+                if opt == '--load-list':
+                    self.load_list = value
             if len(args) == 0 and self.defaultTest is None:
                 # createTests will load tests from self.module
                 self.testNames = None
             elif len(args) > 0:
                 self.testNames = args
-                if __name__ == '__main__':
-                    # to support python -m unittest ...
-                    self.module = None
             else:
                 self.testNames = (self.defaultTest,)
             self.createTests()
         except getopt.error:
-            exc_info = sys.exc_info()
-            msg = exc_info[1]
-            self.usageExit(msg)
+            self.usageExit(sys.exc_info()[1])
 
     def createTests(self):
         if self.testNames is None:
@@ -227,6 +267,10 @@ class TestProgram(object):
                           help="Pattern to match tests ('test*.py' default)")
         parser.add_option('-t', '--top-level-directory', dest='top', default=None,
                           help='Top level directory of project (defaults to start directory)')
+        parser.add_option('-l', '--list', dest='listtests', default=False,
+                          help='List tests rather than running them.')
+        parser.add_option('--load-list', dest='load_list', default=None,
+                          help='Specify a filename containing the test ids to use.')
 
         options, args = parser.parse_args(argv)
         if len(args) > 3:
@@ -243,6 +287,8 @@ class TestProgram(object):
             self.catchbreak = options.catchbreak
         if self.buffer is None:
             self.buffer = options.buffer
+        self.listtests = options.listtests
+        self.load_list = options.load_list
 
         if options.verbose:
             self.verbosity = 2
@@ -276,7 +322,9 @@ class TestProgram(object):
             sys.exit(not self.result.wasSuccessful())
 ################
 
+def main(argv, stdout):
+    runner = TestToolsTestRunner(stdout)
+    program = TestProgram(argv=argv, testRunner=runner, stdout=stdout)
 
 if __name__ == '__main__':
-    runner = TestToolsTestRunner()
-    program = TestProgram(argv=sys.argv, testRunner=runner)
+    main(sys.argv, sys.stdout)
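
As an illustrative sketch, the new main() entry point and --list option above
can also be driven programmatically; the test name mirrors the .testr.conf
added in this patch, and the program name in argv[0] is arbitrary::

    import sys

    from testtools.run import main

    # Write the ids of every test in testtools' own suite to stdout instead
    # of running them.  Passing '--load-list', 'ids.txt' instead would
    # restrict a normal run to the ids listed (one per line) in that file.
    main(['testtools.run', '--list', 'testtools.tests.test_suite'], sys.stdout)
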
index 34954935acf24cdef9ec8d0a126ba4f6d81dea7a..eb5801a4c643001a8738348ae3e558cb3c0bce74 100644 (file)
@@ -1,9 +1,9 @@
-# Copyright (c) 2009 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
 
 """Individual test case execution."""
 
-__metaclass__ = type
 __all__ = [
+    'MultipleExceptions',
     'RunTest',
     ]
 
@@ -12,6 +12,13 @@ import sys
 from testtools.testresult import ExtendedToOriginalDecorator
 
 
+class MultipleExceptions(Exception):
+    """Represents many exceptions raised from some operation.
+
+    :ivar args: The sys.exc_info() tuples for each exception.
+    """
+
+
 class RunTest(object):
     """An object to run a test.
 
@@ -25,15 +32,15 @@ class RunTest(object):
 
     :ivar case: The test case that is to be run.
     :ivar result: The result object a case is reporting to.
-    :ivar handlers: A list of (ExceptionClass->handler code) for exceptions
-        that should be caught if raised from the user code. Exceptions that
-        are caught are checked against this list in first to last order.
-        There is a catchall of Exception at the end of the list, so to add
-        a new exception to the list, insert it at the front (which ensures that
-        it will be checked before any existing base classes in the list. If you
-        add multiple exceptions some of which are subclasses of each other, add
-        the most specific exceptions last (so they come before their parent
-        classes in the list).
+    :ivar handlers: A list of (ExceptionClass, handler_function) for
+        exceptions that should be caught if raised from the user
+        code. Exceptions that are caught are checked against this list in
+        first to last order.  There is a catch-all of `Exception` at the end
+        of the list, so to add a new exception to the list, insert it at the
+        front (which ensures that it will be checked before any existing base
+        classes in the list). If you add multiple exceptions, some of which are
+        subclasses of each other, add the most specific exceptions last (so
+        they come before their parent classes in the list).
     :ivar exception_caught: An object returned when _run_user catches an
         exception.
     :ivar _exceptions: A list of caught exceptions, used to do the single
@@ -108,9 +115,7 @@ class RunTest(object):
         if self.exception_caught == self._run_user(self.case._run_setup,
             self.result):
             # Don't run the test method if we failed getting here.
-            e = self.case._runCleanups(self.result)
-            if e is not None:
-                self._exceptions.append(e)
+            self._run_cleanups(self.result)
             return
         # Run everything from here on in. If any of the methods raise an
         # exception we'll have failed.
@@ -126,30 +131,70 @@ class RunTest(object):
                     failed = True
             finally:
                 try:
-                    e = self._run_user(self.case._runCleanups, self.result)
-                    if e is not None:
-                        self._exceptions.append(e)
+                    if self.exception_caught == self._run_user(
+                        self._run_cleanups, self.result):
                         failed = True
                 finally:
                     if not failed:
                         self.result.addSuccess(self.case,
                             details=self.case.getDetails())
 
-    def _run_user(self, fn, *args):
+    def _run_cleanups(self, result):
+        """Run the cleanups that have been added with addCleanup.
+
+        See the docstring for addCleanup for more information.
+
+        :return: None if all cleanups ran without error,
+            `self.exception_caught` if there was an error.
+        """
+        failing = False
+        while self.case._cleanups:
+            function, arguments, keywordArguments = self.case._cleanups.pop()
+            got_exception = self._run_user(
+                function, *arguments, **keywordArguments)
+            if got_exception == self.exception_caught:
+                failing = True
+        if failing:
+            return self.exception_caught
+
+    def _run_user(self, fn, *args, **kwargs):
         """Run a user supplied function.
 
-        Exceptions are processed by self.handlers.
+        Exceptions are processed by `_got_user_exception`.
+
+        :return: Either whatever 'fn' returns or `self.exception_caught` if
+            'fn' raised an exception.
         """
         try:
-            return fn(*args)
+            return fn(*args, **kwargs)
         except KeyboardInterrupt:
             raise
         except:
-            exc_info = sys.exc_info()
+            return self._got_user_exception(sys.exc_info())
+
+    def _got_user_exception(self, exc_info, tb_label='traceback'):
+        """Called when user code raises an exception.
+
+        If 'exc_info' is a `MultipleExceptions`, then we recurse into it,
+        unpacking the errors that it is made up of.
+
+        :param exc_info: A sys.exc_info() tuple for the user error.
+        :param tb_label: An optional string label for the error.  If
+            not specified, will default to 'traceback'.
+        :return: `exception_caught` if we catch one of the exceptions that
+            have handlers in `self.handlers`, otherwise raise the error.
+        """
+        if exc_info[0] is MultipleExceptions:
+            for sub_exc_info in exc_info[1].args:
+                self._got_user_exception(sub_exc_info, tb_label)
+            return self.exception_caught
+        try:
             e = exc_info[1]
-            self.case.onException(exc_info)
-            for exc_class, handler in self.handlers:
-                if isinstance(e, exc_class):
-                    self._exceptions.append(e)
-                    return self.exception_caught
-            raise e
+            self.case.onException(exc_info, tb_label=tb_label)
+        finally:
+            del exc_info
+        for exc_class, handler in self.handlers:
+            if isinstance(e, exc_class):
+                self._exceptions.append(e)
+                return self.exception_caught
+        raise e
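
As a sketch of how the new MultipleExceptions type is meant to be used (the
helper below is hypothetical, not part of testtools): code run by RunTest can
collect several sys.exc_info() tuples and raise them as a single exception,
which _got_user_exception above then unpacks and reports individually::

    import sys

    from testtools.runtest import MultipleExceptions

    def run_collecting_errors(callables):
        """Call each callable, then re-raise every failure as one exception."""
        errors = []
        for call in callables:
            try:
                call()
            except Exception:
                errors.append(sys.exc_info())
        if errors:
            raise MultipleExceptions(*errors)
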
index 573cd84dc2f4127a23c72ece37cc535d87557a8b..ba7b480355c70e27af27464e029b05d15f9b419f 100644 (file)
@@ -5,24 +5,23 @@
 __metaclass__ = type
 __all__ = [
     'clone_test_with_new_id',
-    'MultipleExceptions',
-    'TestCase',
+    'run_test_with',
     'skip',
     'skipIf',
     'skipUnless',
+    'TestCase',
     ]
 
 import copy
-try:
-    from functools import wraps
-except ImportError:
-    wraps = None
 import itertools
 import sys
 import types
 import unittest
 
-from testtools import content
+from testtools import (
+    content,
+    try_import,
+    )
 from testtools.compat import advance_iterator
 from testtools.matchers import (
     Annotate,
@@ -32,40 +31,64 @@ from testtools.monkey import patch
 from testtools.runtest import RunTest
 from testtools.testresult import TestResult
 
+wraps = try_import('functools.wraps')
 
-try:
-    # Try to use the python2.7 SkipTest exception for signalling skips.
-    from unittest.case import SkipTest as TestSkipped
-except ImportError:
-    class TestSkipped(Exception):
-        """Raised within TestCase.run() when a test is skipped."""
+class TestSkipped(Exception):
+    """Raised within TestCase.run() when a test is skipped."""
+TestSkipped = try_import('unittest.case.SkipTest', TestSkipped)
 
 
-try:
-    # Try to use the same exceptions python 2.7 does.
-    from unittest.case import _ExpectedFailure, _UnexpectedSuccess
-except ImportError:
-    # Oops, not available, make our own.
-    class _UnexpectedSuccess(Exception):
-        """An unexpected success was raised.
-
-        Note that this exception is private plumbing in testtools' testcase
-        module.
-        """
-
-    class _ExpectedFailure(Exception):
-        """An expected failure occured.
-
-        Note that this exception is private plumbing in testtools' testcase
-        module.
-        """
+class _UnexpectedSuccess(Exception):
+    """An unexpected success was raised.
 
+    Note that this exception is private plumbing in testtools' testcase
+    module.
+    """
+_UnexpectedSuccess = try_import(
+    'unittest.case._UnexpectedSuccess', _UnexpectedSuccess)
 
-class MultipleExceptions(Exception):
-    """Represents many exceptions raised from some operation.
+class _ExpectedFailure(Exception):
+    """An expected failure occured.
 
-    :ivar args: The sys.exc_info() tuples for each exception.
+    Note that this exception is private plumbing in testtools' testcase
+    module.
     """
+_ExpectedFailure = try_import(
+    'unittest.case._ExpectedFailure', _ExpectedFailure)
+
+
+def run_test_with(test_runner, **kwargs):
+    """Decorate a test as using a specific `RunTest`.
+
+    e.g.
+      @run_test_with(CustomRunner, timeout=42)
+      def test_foo(self):
+          self.assertTrue(True)
+
+    The returned decorator works by setting an attribute on the decorated
+    function.  `TestCase.__init__` looks for this attribute when deciding
+    on a `RunTest` factory.  If you wish to use multiple decorators on a test
+    method, then you must either make this one the top-most decorator, or
+    you must write your decorators so that they update the wrapping function
+    with the attributes of the wrapped function.  The latter is recommended
+    style anyway.  `functools.wraps`, `functools.update_wrapper` and
+    `twisted.python.util.mergeFunctionMetadata` can help you do this.
+
+    :param test_runner: A `RunTest` factory that takes a test case and an
+        optional list of exception handlers.  See `RunTest`.
+    :param **kwargs: Keyword arguments to pass on as extra arguments to
+        `test_runner`.
+    :return: A decorator to be used for marking a test as needing a special
+        runner.
+    """
+    def decorator(function):
+        # Set an attribute on 'function' which will inform TestCase how to
+        # make the runner.
+        function._run_test_with = (
+            lambda case, handlers=None:
+                test_runner(case, handlers=handlers, **kwargs))
+        return function
+    return decorator
 
 
 class TestCase(unittest.TestCase):
@@ -74,28 +97,41 @@ class TestCase(unittest.TestCase):
     :ivar exception_handlers: Exceptions to catch from setUp, runTest and
         tearDown. This list is able to be modified at any time and consists of
         (exception_class, handler(case, result, exception_value)) pairs.
+    :cvar run_tests_with: A factory to make the `RunTest` to run tests with.
+        Defaults to `RunTest`.  The factory is expected to take a test case
+        and an optional list of exception handlers.
     """
 
     skipException = TestSkipped
 
+    run_tests_with = RunTest
+
     def __init__(self, *args, **kwargs):
         """Construct a TestCase.
 
         :param testMethod: The name of the method to run.
         :param runTest: Optional class to use to execute the test. If not
-            supplied testtools.runtest.RunTest is used. The instance to be
+            supplied `testtools.runtest.RunTest` is used. The instance to be
             used is created when run() is invoked, so will be fresh each time.
+            Overrides `run_tests_with` if given.
         """
+        runTest = kwargs.pop('runTest', None)
         unittest.TestCase.__init__(self, *args, **kwargs)
         self._cleanups = []
         self._unique_id_gen = itertools.count(1)
-        self._traceback_id_gen = itertools.count(0)
+        # Generators to ensure unique traceback ids.  Maps traceback label to
+        # iterators.
+        self._traceback_id_gens = {}
         self.__setup_called = False
         self.__teardown_called = False
         # __details is lazy-initialized so that a constructed-but-not-run
         # TestCase is safe to use with clone_test_with_new_id.
         self.__details = None
-        self.__RunTest = kwargs.get('runTest', RunTest)
+        test_method = self._get_test_method()
+        if runTest is None:
+            runTest = getattr(
+                test_method, '_run_test_with', self.run_tests_with)
+        self.__RunTest = runTest
         self.__exception_handlers = []
         self.exception_handlers = [
             (self.skipException, self._report_skip),
@@ -180,32 +216,6 @@ class TestCase(unittest.TestCase):
             className = ', '.join(klass.__name__ for klass in classOrIterable)
         return className
 
-    def _runCleanups(self, result):
-        """Run the cleanups that have been added with addCleanup.
-
-        See the docstring for addCleanup for more information.
-
-        :return: None if all cleanups ran without error, the most recently
-            raised exception from the cleanups otherwise.
-        """
-        last_exception = None
-        while self._cleanups:
-            function, arguments, keywordArguments = self._cleanups.pop()
-            try:
-                function(*arguments, **keywordArguments)
-            except KeyboardInterrupt:
-                raise
-            except:
-                exceptions = [sys.exc_info()]
-                while exceptions:
-                    exc_info = exceptions.pop()
-                    if exc_info[0] is MultipleExceptions:
-                        exceptions.extend(exc_info[1].args)
-                        continue
-                    self._report_traceback(exc_info)
-                    last_exception = exc_info[1]
-        return last_exception
-
     def addCleanup(self, function, *arguments, **keywordArguments):
         """Add a cleanup function to be called after tearDown.
 
@@ -356,9 +366,14 @@ class TestCase(unittest.TestCase):
         try:
             predicate(*args, **kwargs)
         except self.failureException:
+            # GZ 2010-08-12: Don't know how to avoid exc_info cycle as the new
+            #                unittest _ExpectedFailure wants old traceback
             exc_info = sys.exc_info()
-            self._report_traceback(exc_info)
-            raise _ExpectedFailure(exc_info)
+            try:
+                self._report_traceback(exc_info)
+                raise _ExpectedFailure(exc_info)
+            finally:
+                del exc_info
         else:
             raise _UnexpectedSuccess(reason)
 
@@ -386,14 +401,14 @@ class TestCase(unittest.TestCase):
             prefix = self.id()
         return '%s-%d' % (prefix, self.getUniqueInteger())
 
-    def onException(self, exc_info):
+    def onException(self, exc_info, tb_label='traceback'):
         """Called when an exception propogates from test code.
 
         :seealso addOnException:
         """
         if exc_info[0] not in [
             TestSkipped, _UnexpectedSuccess, _ExpectedFailure]:
-            self._report_traceback(exc_info)
+            self._report_traceback(exc_info, tb_label=tb_label)
         for handler in self.__exception_handlers:
             handler(exc_info)
 
@@ -418,12 +433,12 @@ class TestCase(unittest.TestCase):
         self._add_reason(reason)
         result.addSkip(self, details=self.getDetails())
 
-    def _report_traceback(self, exc_info):
-        tb_id = advance_iterator(self._traceback_id_gen)
+    def _report_traceback(self, exc_info, tb_label='traceback'):
+        id_gen = self._traceback_id_gens.setdefault(
+            tb_label, itertools.count(0))
+        tb_id = advance_iterator(id_gen)
         if tb_id:
-            tb_label = 'traceback-%d' % tb_id
-        else:
-            tb_label = 'traceback'
+            tb_label = '%s-%d' % (tb_label, tb_id)
         self.addDetail(tb_label, content.TracebackContent(exc_info, self))
 
     @staticmethod
@@ -440,13 +455,14 @@ class TestCase(unittest.TestCase):
         :raises ValueError: If the base class setUp is not called, a
             ValueError is raised.
         """
-        self.setUp()
+        ret = self.setUp()
         if not self.__setup_called:
             raise ValueError(
                 "TestCase.setUp was not called. Have you upcalled all the "
                 "way up the hierarchy from your setUp? e.g. Call "
                 "super(%s, self).setUp() from your setUp()."
                 % self.__class__.__name__)
+        return ret
 
     def _run_teardown(self, result):
         """Run the tearDown function for this test.
@@ -455,28 +471,60 @@ class TestCase(unittest.TestCase):
         :raises ValueError: If the base class tearDown is not called, a
             ValueError is raised.
         """
-        self.tearDown()
+        ret = self.tearDown()
         if not self.__teardown_called:
             raise ValueError(
                 "TestCase.tearDown was not called. Have you upcalled all the "
                 "way up the hierarchy from your tearDown? e.g. Call "
                 "super(%s, self).tearDown() from your tearDown()."
                 % self.__class__.__name__)
+        return ret
 
-    def _run_test_method(self, result):
-        """Run the test method for this test.
-
-        :param result: A testtools.TestResult to report activity to.
-        :return: None.
-        """
+    def _get_test_method(self):
         absent_attr = object()
         # Python 2.5+
         method_name = getattr(self, '_testMethodName', absent_attr)
         if method_name is absent_attr:
             # Python 2.4
             method_name = getattr(self, '_TestCase__testMethodName')
-        testMethod = getattr(self, method_name)
-        testMethod()
+        return getattr(self, method_name)
+
+    def _run_test_method(self, result):
+        """Run the test method for this test.
+
+        :param result: A testtools.TestResult to report activity to.
+        :return: None.
+        """
+        return self._get_test_method()()
+
+    def useFixture(self, fixture):
+        """Use fixture in a test case.
+
+        The fixture will be set up, and self.addCleanup(fixture.cleanUp) called.
+
+        :param fixture: The fixture to use.
+        :return: The fixture, after setting it up and scheduling a cleanup for
+           it.
+        """
+        fixture.setUp()
+        self.addCleanup(fixture.cleanUp)
+        self.addCleanup(self._gather_details, fixture.getDetails)
+        return fixture
+
+    def _gather_details(self, getDetails):
+        """Merge the details from getDetails() into self.getDetails()."""
+        details = getDetails()
+        my_details = self.getDetails()
+        for name, content_object in details.items():
+            new_name = name
+            disambiguator = itertools.count(1)
+            while new_name in my_details:
+                new_name = '%s-%d' % (name, advance_iterator(disambiguator))
+            name = new_name
+            content_bytes = list(content_object.iter_bytes())
+            content_callback = lambda content_bytes=content_bytes: content_bytes
+            self.addDetail(name,
+                content.Content(content_object.content_type, content_callback))
 
     def setUp(self):
         unittest.TestCase.setUp(self)
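
For illustration, a minimal sketch of useFixture; SimpleFixture below is a
stand-in (not a real fixtures class) for anything implementing the
setUp/cleanUp/getDetails protocol that useFixture and _gather_details rely on::

    from testtools import TestCase

    class SimpleFixture(object):
        """A stand-in fixture offering setUp(), cleanUp() and getDetails()."""

        def setUp(self):
            self.path = '/tmp/scratch'

        def cleanUp(self):
            self.path = None

        def getDetails(self):
            return {}

    class FixtureExample(TestCase):

        def test_uses_fixture(self):
            # setUp is called immediately; cleanUp is scheduled via addCleanup.
            fixture = self.useFixture(SimpleFixture())
            self.assertEqual('/tmp/scratch', fixture.path)
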
index d231c919c24b69631993ee3ab90f2762b30f5442..7e4a2c9b4126ad3cc229e65546d973ad1edd0afc 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2009 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
 
 """Doubles of test result objects, useful for testing unittest code."""
 
@@ -15,15 +15,18 @@ class LoggingBase(object):
     def __init__(self):
         self._events = []
         self.shouldStop = False
+        self._was_successful = True
 
 
 class Python26TestResult(LoggingBase):
     """A precisely python 2.6 like test result, that logs."""
 
     def addError(self, test, err):
+        self._was_successful = False
         self._events.append(('addError', test, err))
 
     def addFailure(self, test, err):
+        self._was_successful = False
         self._events.append(('addFailure', test, err))
 
     def addSuccess(self, test):
@@ -38,6 +41,9 @@ class Python26TestResult(LoggingBase):
     def stopTest(self, test):
         self._events.append(('stopTest', test))
 
+    def wasSuccessful(self):
+        return self._was_successful
+
 
 class Python27TestResult(Python26TestResult):
     """A precisely python 2.7 like test result, that logs."""
@@ -62,9 +68,11 @@ class ExtendedTestResult(Python27TestResult):
     """A test result like the proposed extended unittest result API."""
 
     def addError(self, test, err=None, details=None):
+        self._was_successful = False
         self._events.append(('addError', test, err or details))
 
     def addFailure(self, test, err=None, details=None):
+        self._was_successful = False
         self._events.append(('addFailure', test, err or details))
 
     def addExpectedFailure(self, test, err=None, details=None):
@@ -80,6 +88,7 @@ class ExtendedTestResult(Python27TestResult):
             self._events.append(('addSuccess', test))
 
     def addUnexpectedSuccess(self, test, details=None):
+        self._was_successful = False
         if details is not None:
             self._events.append(('addUnexpectedSuccess', test, details))
         else:
@@ -88,8 +97,15 @@ class ExtendedTestResult(Python27TestResult):
     def progress(self, offset, whence):
         self._events.append(('progress', offset, whence))
 
+    def startTestRun(self):
+        super(ExtendedTestResult, self).startTestRun()
+        self._was_successful = True
+
     def tags(self, new_tags, gone_tags):
         self._events.append(('tags', new_tags, gone_tags))
 
     def time(self, time):
         self._events.append(('time', time))
+
+    def wasSuccessful(self):
+        return self._was_successful
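
A short sketch of the doubles in use: they log every call in _events and,
with the changes above, also answer wasSuccessful(); the test id and error
tuple below are placeholders::

    from testtools.testresult.doubles import ExtendedTestResult

    result = ExtendedTestResult()
    result.addError('test-id', ('exc type', 'exc value', 'exc traceback'))
    # The double logs the call and, with this patch, reports failure.
    assert ('addError', 'test-id',
        ('exc type', 'exc value', 'exc traceback')) in result._events
    assert not result.wasSuccessful()
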
index 95f6e8f04c17ae618cb15734be11360655b51dc1..d1a10236452f42b25e8e71cdc5c9fc15f5dea766 100644 (file)
@@ -11,6 +11,7 @@ __all__ = [
     ]
 
 import datetime
+import sys
 import unittest
 
 from testtools.compat import _format_exc_info, str_is_unicode, _u
@@ -34,13 +35,11 @@ class TestResult(unittest.TestResult):
     """
 
     def __init__(self):
-        super(TestResult, self).__init__()
-        self.skip_reasons = {}
-        self.__now = None
-        # -- Start: As per python 2.7 --
-        self.expectedFailures = []
-        self.unexpectedSuccesses = []
-        # -- End:   As per python 2.7 --
+        # startTestRun resets all attributes, and older clients don't know to
+        # call startTestRun, so it is called once here.
+        # Because subclasses may reasonably not expect this, we call the 
+        # specific version we want to run.
+        TestResult.startTestRun(self)
 
     def addExpectedFailure(self, test, err=None, details=None):
         """Called when a test has failed in an expected manner.
@@ -107,6 +106,18 @@ class TestResult(unittest.TestResult):
         """Called when a test was expected to fail, but succeed."""
         self.unexpectedSuccesses.append(test)
 
+    def wasSuccessful(self):
+        """Has this result been successful so far?
+
+        If there have been any errors, failures or unexpected successes,
+        return False.  Otherwise, return True.
+
+        Note: This differs from standard unittest in that we consider
+        unexpected successes to be equivalent to failures, rather than
+        successes.
+        """
+        return not (self.errors or self.failures or self.unexpectedSuccesses)
+
     if str_is_unicode:
         # Python 3 and IronPython strings are unicode, use parent class method
         _exc_info_to_unicode = unittest.TestResult._exc_info_to_string
@@ -145,8 +156,16 @@ class TestResult(unittest.TestResult):
     def startTestRun(self):
         """Called before a test run starts.
 
-        New in python 2.7
+        New in python 2.7. The testtools version resets the result to a
+        pristine condition ready for use in another test run.
         """
+        super(TestResult, self).__init__()
+        self.skip_reasons = {}
+        self.__now = None
+        # -- Start: As per python 2.7 --
+        self.expectedFailures = []
+        self.unexpectedSuccesses = []
+        # -- End:   As per python 2.7 --
 
     def stopTestRun(self):
         """Called after a test run completes
@@ -181,7 +200,7 @@ class MultiTestResult(TestResult):
 
     def __init__(self, *results):
         TestResult.__init__(self)
-        self._results = map(ExtendedToOriginalDecorator, results)
+        self._results = list(map(ExtendedToOriginalDecorator, results))
 
     def _dispatch(self, message, *args, **kwargs):
         return tuple(
@@ -222,6 +241,13 @@ class MultiTestResult(TestResult):
     def done(self):
         return self._dispatch('done')
 
+    def wasSuccessful(self):
+        """Was this result successful?
+
+        Only returns True if every constituent result was successful.
+        """
+        return all(self._dispatch('wasSuccessful'))
+
 
 class TextTestResult(TestResult):
     """A TestResult which outputs activity to a text stream."""
@@ -257,6 +283,10 @@ class TextTestResult(TestResult):
         stop = self._now()
         self._show_list('ERROR', self.errors)
         self._show_list('FAIL', self.failures)
+        for test in self.unexpectedSuccesses:
+            self.stream.write(
+                "%sUNEXPECTED SUCCESS: %s\n%s" % (
+                    self.sep1, test.id(), self.sep2))
         self.stream.write("Ran %d test%s in %.3fs\n\n" %
             (self.testsRun, plural,
              self._delta_to_float(stop - self.__start)))
@@ -266,7 +296,8 @@ class TextTestResult(TestResult):
             self.stream.write("FAILED (")
             details = []
             details.append("failures=%d" % (
-                len(self.failures) + len(self.errors)))
+                sum(map(len, (
+                    self.failures, self.errors, self.unexpectedSuccesses)))))
             self.stream.write(", ".join(details))
             self.stream.write(")\n")
         super(TextTestResult, self).stopTestRun()
@@ -300,59 +331,42 @@ class ThreadsafeForwardingResult(TestResult):
         self.result = ExtendedToOriginalDecorator(target)
         self.semaphore = semaphore
 
-    def addError(self, test, err=None, details=None):
+    def _add_result_with_semaphore(self, method, test, *args, **kwargs):
         self.semaphore.acquire()
         try:
+            self.result.time(self._test_start)
             self.result.startTest(test)
-            self.result.addError(test, err, details=details)
-            self.result.stopTest(test)
+            self.result.time(self._now())
+            try:
+                method(test, *args, **kwargs)
+            finally:
+                self.result.stopTest(test)
         finally:
             self.semaphore.release()
 
+    def addError(self, test, err=None, details=None):
+        self._add_result_with_semaphore(self.result.addError,
+            test, err, details=details)
+
     def addExpectedFailure(self, test, err=None, details=None):
-        self.semaphore.acquire()
-        try:
-            self.result.startTest(test)
-            self.result.addExpectedFailure(test, err, details=details)
-            self.result.stopTest(test)
-        finally:
-            self.semaphore.release()
+        self._add_result_with_semaphore(self.result.addExpectedFailure,
+            test, err, details=details)
 
     def addFailure(self, test, err=None, details=None):
-        self.semaphore.acquire()
-        try:
-            self.result.startTest(test)
-            self.result.addFailure(test, err, details=details)
-            self.result.stopTest(test)
-        finally:
-            self.semaphore.release()
+        self._add_result_with_semaphore(self.result.addFailure,
+            test, err, details=details)
 
     def addSkip(self, test, reason=None, details=None):
-        self.semaphore.acquire()
-        try:
-            self.result.startTest(test)
-            self.result.addSkip(test, reason, details=details)
-            self.result.stopTest(test)
-        finally:
-            self.semaphore.release()
+        self._add_result_with_semaphore(self.result.addSkip,
+            test, reason, details=details)
 
     def addSuccess(self, test, details=None):
-        self.semaphore.acquire()
-        try:
-            self.result.startTest(test)
-            self.result.addSuccess(test, details=details)
-            self.result.stopTest(test)
-        finally:
-            self.semaphore.release()
+        self._add_result_with_semaphore(self.result.addSuccess,
+            test, details=details)
 
     def addUnexpectedSuccess(self, test, details=None):
-        self.semaphore.acquire()
-        try:
-            self.result.startTest(test)
-            self.result.addUnexpectedSuccess(test, details=details)
-            self.result.stopTest(test)
-        finally:
-            self.semaphore.release()
+        self._add_result_with_semaphore(self.result.addUnexpectedSuccess,
+            test, details=details)
 
     def startTestRun(self):
         self.semaphore.acquire()
@@ -375,6 +389,13 @@ class ThreadsafeForwardingResult(TestResult):
         finally:
             self.semaphore.release()
 
+    def startTest(self, test):
+        self._test_start = self._now()
+        super(ThreadsafeForwardingResult, self).startTest(test)
+
+    def wasSuccessful(self):
+        return self.result.wasSuccessful()
+
 
 class ExtendedToOriginalDecorator(object):
     """Permit new TestResult API code to degrade gracefully with old results.
@@ -435,14 +456,20 @@ class ExtendedToOriginalDecorator(object):
             try:
                 return addSkip(test, details=details)
             except TypeError:
-                # have to convert
-                reason = _details_to_str(details)
+                # extract the reason if it's available
+                try:
+                    reason = ''.join(details['reason'].iter_text())
+                except KeyError:
+                    reason = _details_to_str(details)
         return addSkip(test, reason)
 
     def addUnexpectedSuccess(self, test, details=None):
         outcome = getattr(self.decorated, 'addUnexpectedSuccess', None)
         if outcome is None:
-            return self.decorated.addSuccess(test)
+            try:
+                test.fail("")
+            except test.failureException:
+                return self.addFailure(test, sys.exc_info())
         if details is not None:
             try:
                 return outcome(test, details=details)
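
To illustrate the new wasSuccessful() behaviour (unexpected successes count as
failures, and MultiTestResult only reports success when every constituent
does), a small sketch; importing MultiTestResult from testtools.testresult
assumes the package's usual re-exports, and the test class is illustrative::

    from testtools import TestCase
    from testtools.testresult import MultiTestResult, TestResult

    class Surprising(TestCase):
        def test_unexpectedly_passes(self):
            self.expectFailure("expected to fail", self.assertTrue, True)

    one, two = TestResult(), TestResult()
    combined = MultiTestResult(one, two)
    Surprising('test_unexpectedly_passes').run(combined)
    # Unlike standard unittest, the unexpected success makes each constituent
    # result, and therefore the combined result, report failure.
    assert not one.wasSuccessful()
    assert not combined.wasSuccessful()
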
index 5e22000bb46d714ce60ff5811a967f2f923a7fbb..ac3c218de9766f376f2ea585d1b1d3148f5f1cc5 100644 (file)
@@ -3,32 +3,39 @@
 # See README for copyright and licensing details.
 
 import unittest
-from testtools.tests import (
-    test_compat,
-    test_content,
-    test_content_type,
-    test_matchers,
-    test_monkey,
-    test_runtest,
-    test_testtools,
-    test_testresult,
-    test_testsuite,
-    )
 
 
 def test_suite():
-    suites = []
-    modules = [
+    from testtools.tests import (
         test_compat,
         test_content,
         test_content_type,
+        test_deferredruntest,
+        test_fixturesupport,
+        test_helpers,
         test_matchers,
         test_monkey,
+        test_run,
         test_runtest,
+        test_spinner,
+        test_testtools,
+        test_testresult,
+        test_testsuite,
+        )
+    modules = [
+        test_compat,
+        test_content,
+        test_content_type,
+        test_deferredruntest,
+        test_fixturesupport,
+        test_helpers,
+        test_matchers,
+        test_monkey,
+        test_run,
+        test_spinner,
         test_testresult,
         test_testsuite,
         test_testtools,
         ]
-    for module in modules:
-        suites.append(getattr(module, 'test_suite')())
+    suites = map(lambda x:x.test_suite(), modules)
     return unittest.TestSuite(suites)
index c4cf10c73651bf4f87d426f9c9f3199a08324d14..5f3187db296c4ab1b2c44959ec2871a9a2676701 100644 (file)
@@ -12,6 +12,7 @@ __all__ = [
 from testtools import TestResult
 
 
+# GZ 2010-08-12: Don't do this, pointlessly creates an exc_info cycle
 try:
     raise Exception
 except Exception:
@@ -62,6 +63,10 @@ class LoggingResult(TestResult):
         self._events.append('done')
         super(LoggingResult, self).done()
 
+    def time(self, a_datetime):
+        self._events.append(('time', a_datetime))
+        super(LoggingResult, self).time(a_datetime)
+
 # Note, the following three classes are different to LoggingResult by
 # being fully defined exact matches rather than supersets.
 from testtools.testresult.doubles import *
index 138b286d5db71b014d04a2ed2437badcd5440901..856953896a9d6386f7f82805c45eb4e40a5a5e99 100644 (file)
@@ -17,6 +17,11 @@ from testtools.compat import (
     _u,
     unicode_output_stream,
     )
+from testtools.matchers import (
+    MatchesException,
+    Not,
+    Raises,
+    )
 
 
 class TestDetectEncoding(testtools.TestCase):
@@ -192,34 +197,34 @@ class TestUnicodeOutputStream(testtools.TestCase):
         super(TestUnicodeOutputStream, self).setUp()
         if sys.platform == "cli":
             self.skip("IronPython shouldn't wrap streams to do encoding")
-    
+
     def test_no_encoding_becomes_ascii(self):
         """A stream with no encoding attribute gets ascii/replace strings"""
         sout = _FakeOutputStream()
         unicode_output_stream(sout).write(self.uni)
         self.assertEqual([_b("pa???n")], sout.writelog)
-    
+
     def test_encoding_as_none_becomes_ascii(self):
         """A stream with encoding value of None gets ascii/replace strings"""
         sout = _FakeOutputStream()
         sout.encoding = None
         unicode_output_stream(sout).write(self.uni)
         self.assertEqual([_b("pa???n")], sout.writelog)
-    
+
     def test_bogus_encoding_becomes_ascii(self):
         """A stream with a bogus encoding gets ascii/replace strings"""
         sout = _FakeOutputStream()
         sout.encoding = "bogus"
         unicode_output_stream(sout).write(self.uni)
         self.assertEqual([_b("pa???n")], sout.writelog)
-    
+
     def test_partial_encoding_replace(self):
         """A string which can be partly encoded correctly should be"""
         sout = _FakeOutputStream()
         sout.encoding = "iso-8859-7"
         unicode_output_stream(sout).write(self.uni)
         self.assertEqual([_b("pa?\xe8?n")], sout.writelog)
-    
+
     def test_unicode_encodings_not_wrapped(self):
         """A unicode encoding is left unwrapped as needs no error handler"""
         sout = _FakeOutputStream()
@@ -228,7 +233,7 @@ class TestUnicodeOutputStream(testtools.TestCase):
         sout = _FakeOutputStream()
         sout.encoding = "utf-16-be"
         self.assertIs(unicode_output_stream(sout), sout)
-    
+
     def test_stringio(self):
         """A StringIO object should maybe get an ascii native str type"""
         try:
@@ -241,7 +246,8 @@ class TestUnicodeOutputStream(testtools.TestCase):
         soutwrapper = unicode_output_stream(sout)
         if newio:
             self.expectFailure("Python 3 StringIO expects text not bytes",
-                self.assertRaises, TypeError, soutwrapper.write, self.uni)
+                self.assertThat, lambda: soutwrapper.write(self.uni),
+                Not(Raises(MatchesException(TypeError))))
         soutwrapper.write(self.uni)
         self.assertEqual("pa???n", sout.getvalue())
 
index 741256ef7a911b5e2b6bdbc155e93c4c11aaf0a7..eaf50c7f373e41cdfe7d2e7bbfa1346c8b1224fc 100644 (file)
@@ -2,19 +2,24 @@
 
 import unittest
 from testtools import TestCase
-from testtools.compat import _u
-from testtools.content import Content, TracebackContent
-from testtools.content_type import ContentType
+from testtools.compat import _b, _u
+from testtools.content import Content, TracebackContent, text_content
+from testtools.content_type import ContentType, UTF8_TEXT
+from testtools.matchers import MatchesException, Raises
 from testtools.tests.helpers import an_exc_info
 
 
+raises_value_error = Raises(MatchesException(ValueError))
+
+
 class TestContent(TestCase):
 
     def test___init___None_errors(self):
-        self.assertRaises(ValueError, Content, None, None)
-        self.assertRaises(ValueError, Content, None, lambda: ["traceback"])
-        self.assertRaises(ValueError, Content,
-            ContentType("text", "traceback"), None)
+        self.assertThat(lambda:Content(None, None), raises_value_error)
+        self.assertThat(lambda:Content(None, lambda: ["traceback"]),
+            raises_value_error)
+        self.assertThat(lambda:Content(ContentType("text", "traceback"), None),
+            raises_value_error)
 
     def test___init___sets_ivars(self):
         content_type = ContentType("foo", "bar")
@@ -24,20 +29,27 @@ class TestContent(TestCase):
 
     def test___eq__(self):
         content_type = ContentType("foo", "bar")
-        content1 = Content(content_type, lambda: ["bytes"])
-        content2 = Content(content_type, lambda: ["bytes"])
-        content3 = Content(content_type, lambda: ["by", "tes"])
-        content4 = Content(content_type, lambda: ["by", "te"])
-        content5 = Content(ContentType("f", "b"), lambda: ["by", "tes"])
+        one_chunk = lambda: [_b("bytes")]
+        two_chunk = lambda: [_b("by"), _b("tes")]
+        content1 = Content(content_type, one_chunk)
+        content2 = Content(content_type, one_chunk)
+        content3 = Content(content_type, two_chunk)
+        content4 = Content(content_type, lambda: [_b("by"), _b("te")])
+        content5 = Content(ContentType("f", "b"), two_chunk)
         self.assertEqual(content1, content2)
         self.assertEqual(content1, content3)
         self.assertNotEqual(content1, content4)
         self.assertNotEqual(content1, content5)
 
+    def test___repr__(self):
+        content = Content(ContentType("application", "octet-stream"),
+            lambda: [_b("\x00bin"), _b("ary\xff")])
+        self.assertIn("\\x00binary\\xff", repr(content))
+
     def test_iter_text_not_text_errors(self):
         content_type = ContentType("foo", "bar")
         content = Content(content_type, lambda: ["bytes"])
-        self.assertRaises(ValueError, content.iter_text)
+        self.assertThat(content.iter_text, raises_value_error)
 
     def test_iter_text_decodes(self):
         content_type = ContentType("text", "strange", {"charset": "utf8"})
@@ -56,7 +68,8 @@ class TestContent(TestCase):
 class TestTracebackContent(TestCase):
 
     def test___init___None_errors(self):
-        self.assertRaises(ValueError, TracebackContent, None, None)
+        self.assertThat(lambda:TracebackContent(None, None),
+            raises_value_error) 
 
     def test___init___sets_ivars(self):
         content = TracebackContent(an_exc_info, self)
@@ -68,6 +81,14 @@ class TestTracebackContent(TestCase):
         self.assertEqual(expected, ''.join(list(content.iter_text())))
 
 
+class TestBytesContent(TestCase):
+
+    def test_bytes(self):
+        data = _u("some data")
+        expected = Content(UTF8_TEXT, lambda: [data.encode('utf8')])
+        self.assertEqual(expected, text_content(data))
+
+
 def test_suite():
     from unittest import TestLoader
     return TestLoader().loadTestsFromName(__name__)
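
For illustration, the new text_content helper tested above wraps a unicode
string as a UTF-8 text/plain detail, which pairs naturally with addDetail;
the detail name and text below are arbitrary::

    from testtools import TestCase
    from testtools.content import text_content

    class DetailExample(TestCase):

        def test_attaches_text_detail(self):
            self.addDetail('server-log', text_content('connection refused'))
            self.assertIn('server-log', self.getDetails())
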
index d593a14eaf014fad4da8cf82e2b9b4c3d1e4d686..52f4afac0559e2112bab211bae74b493001b08f7 100644 (file)
@@ -1,16 +1,18 @@
 # Copyright (c) 2008 Jonathan M. Lange. See LICENSE for details.
 
 from testtools import TestCase
-from testtools.matchers import Equals
+from testtools.matchers import Equals, MatchesException, Raises
 from testtools.content_type import ContentType, UTF8_TEXT
 
 
 class TestContentType(TestCase):
 
     def test___init___None_errors(self):
-        self.assertRaises(ValueError, ContentType, None, None)
-        self.assertRaises(ValueError, ContentType, None, "traceback")
-        self.assertRaises(ValueError, ContentType, "text", None)
+        raises_value_error = Raises(MatchesException(ValueError))
+        self.assertThat(lambda:ContentType(None, None), raises_value_error)
+        self.assertThat(lambda:ContentType(None, "traceback"),
+            raises_value_error)
+        self.assertThat(lambda:ContentType("text", None), raises_value_error)
 
     def test___init___sets_ivars(self):
         content_type = ContentType("foo", "bar")
diff --git a/lib/testtools/testtools/tests/test_deferredruntest.py b/lib/testtools/testtools/tests/test_deferredruntest.py
new file mode 100644 (file)
index 0000000..04614df
--- /dev/null
@@ -0,0 +1,738 @@
+# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
+
+"""Tests for the DeferredRunTest single test execution logic."""
+
+import os
+import signal
+
+from testtools import (
+    skipIf,
+    TestCase,
+    )
+from testtools.content import (
+    text_content,
+    )
+from testtools.helpers import try_import
+from testtools.tests.helpers import ExtendedTestResult
+from testtools.matchers import (
+    Equals,
+    KeysEqual,
+    MatchesException,
+    Raises,
+    )
+from testtools.runtest import RunTest
+from testtools.tests.test_spinner import NeedsTwistedTestCase
+
+assert_fails_with = try_import('testtools.deferredruntest.assert_fails_with')
+AsynchronousDeferredRunTest = try_import(
+    'testtools.deferredruntest.AsynchronousDeferredRunTest')
+flush_logged_errors = try_import(
+    'testtools.deferredruntest.flush_logged_errors')
+SynchronousDeferredRunTest = try_import(
+    'testtools.deferredruntest.SynchronousDeferredRunTest')
+
+defer = try_import('twisted.internet.defer')
+failure = try_import('twisted.python.failure')
+log = try_import('twisted.python.log')
+DelayedCall = try_import('twisted.internet.base.DelayedCall')
+
+
+class X(object):
+    """Tests that we run as part of our tests, nested to avoid discovery."""
+
+    class Base(TestCase):
+        def setUp(self):
+            super(X.Base, self).setUp()
+            self.calls = ['setUp']
+            self.addCleanup(self.calls.append, 'clean-up')
+        def test_something(self):
+            self.calls.append('test')
+        def tearDown(self):
+            self.calls.append('tearDown')
+            super(X.Base, self).tearDown()
+
+    class ErrorInSetup(Base):
+        expected_calls = ['setUp', 'clean-up']
+        expected_results = [('addError', RuntimeError)]
+        def setUp(self):
+            super(X.ErrorInSetup, self).setUp()
+            raise RuntimeError("Error in setUp")
+
+    class ErrorInTest(Base):
+        expected_calls = ['setUp', 'tearDown', 'clean-up']
+        expected_results = [('addError', RuntimeError)]
+        def test_something(self):
+            raise RuntimeError("Error in test")
+
+    class FailureInTest(Base):
+        expected_calls = ['setUp', 'tearDown', 'clean-up']
+        expected_results = [('addFailure', AssertionError)]
+        def test_something(self):
+            self.fail("test failed")
+
+    class ErrorInTearDown(Base):
+        expected_calls = ['setUp', 'test', 'clean-up']
+        expected_results = [('addError', RuntimeError)]
+        def tearDown(self):
+            raise RuntimeError("Error in tearDown")
+
+    class ErrorInCleanup(Base):
+        expected_calls = ['setUp', 'test', 'tearDown', 'clean-up']
+        expected_results = [('addError', ZeroDivisionError)]
+        def test_something(self):
+            self.calls.append('test')
+            self.addCleanup(lambda: 1/0)
+
+    class TestIntegration(NeedsTwistedTestCase):
+
+        def assertResultsMatch(self, test, result):
+            events = list(result._events)
+            self.assertEqual(('startTest', test), events.pop(0))
+            for expected_result in test.expected_results:
+                result = events.pop(0)
+                if len(expected_result) == 1:
+                    self.assertEqual((expected_result[0], test), result)
+                else:
+                    self.assertEqual((expected_result[0], test), result[:2])
+                    error_type = expected_result[1]
+                    self.assertIn(error_type.__name__, str(result[2]))
+            self.assertEqual([('stopTest', test)], events)
+
+        def test_runner(self):
+            result = ExtendedTestResult()
+            test = self.test_factory('test_something', runTest=self.runner)
+            test.run(result)
+            self.assertEqual(test.calls, self.test_factory.expected_calls)
+            self.assertResultsMatch(test, result)
+
+
+def make_integration_tests():
+    from unittest import TestSuite
+    from testtools import clone_test_with_new_id
+    runners = [
+        ('RunTest', RunTest),
+        ('SynchronousDeferredRunTest', SynchronousDeferredRunTest),
+        ('AsynchronousDeferredRunTest', AsynchronousDeferredRunTest),
+        ]
+
+    tests = [
+        X.ErrorInSetup,
+        X.ErrorInTest,
+        X.ErrorInTearDown,
+        X.FailureInTest,
+        X.ErrorInCleanup,
+        ]
+    base_test = X.TestIntegration('test_runner')
+    integration_tests = []
+    for runner_name, runner in runners:
+        for test in tests:
+            new_test = clone_test_with_new_id(
+                base_test, '%s(%s, %s)' % (
+                    base_test.id(),
+                    runner_name,
+                    test.__name__))
+            new_test.test_factory = test
+            new_test.runner = runner
+            integration_tests.append(new_test)
+    return TestSuite(integration_tests)
+
+
+class TestSynchronousDeferredRunTest(NeedsTwistedTestCase):
+
+    def make_result(self):
+        return ExtendedTestResult()
+
+    def make_runner(self, test):
+        return SynchronousDeferredRunTest(test, test.exception_handlers)
+
+    def test_success(self):
+        class SomeCase(TestCase):
+            def test_success(self):
+                return defer.succeed(None)
+        test = SomeCase('test_success')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            result._events, Equals([
+                ('startTest', test),
+                ('addSuccess', test),
+                ('stopTest', test)]))
+
+    def test_failure(self):
+        class SomeCase(TestCase):
+            def test_failure(self):
+                return defer.maybeDeferred(self.fail, "Egads!")
+        test = SomeCase('test_failure')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events], Equals([
+                ('startTest', test),
+                ('addFailure', test),
+                ('stopTest', test)]))
+
+    def test_setUp_followed_by_test(self):
+        class SomeCase(TestCase):
+            def setUp(self):
+                super(SomeCase, self).setUp()
+                return defer.succeed(None)
+            def test_failure(self):
+                return defer.maybeDeferred(self.fail, "Egads!")
+        test = SomeCase('test_failure')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events], Equals([
+                ('startTest', test),
+                ('addFailure', test),
+                ('stopTest', test)]))
+
+
+class TestAsynchronousDeferredRunTest(NeedsTwistedTestCase):
+
+    def make_reactor(self):
+        from twisted.internet import reactor
+        return reactor
+
+    def make_result(self):
+        return ExtendedTestResult()
+
+    def make_runner(self, test, timeout=None):
+        if timeout is None:
+            timeout = self.make_timeout()
+        return AsynchronousDeferredRunTest(
+            test, test.exception_handlers, timeout=timeout)
+
+    def make_timeout(self):
+        return 0.005
+
+    def test_setUp_returns_deferred_that_fires_later(self):
+        # setUp can return a Deferred that might fire at any time.
+        # AsynchronousDeferredRunTest will not go on to running the test until
+        # the Deferred returned by setUp actually fires.
+        call_log = []
+        marker = object()
+        d = defer.Deferred().addCallback(call_log.append)
+        class SomeCase(TestCase):
+            def setUp(self):
+                super(SomeCase, self).setUp()
+                call_log.append('setUp')
+                return d
+            def test_something(self):
+                call_log.append('test')
+        def fire_deferred():
+            self.assertThat(call_log, Equals(['setUp']))
+            d.callback(marker)
+        test = SomeCase('test_something')
+        timeout = self.make_timeout()
+        runner = self.make_runner(test, timeout=timeout)
+        result = self.make_result()
+        reactor = self.make_reactor()
+        reactor.callLater(timeout, fire_deferred)
+        runner.run(result)
+        self.assertThat(call_log, Equals(['setUp', marker, 'test']))
+
+    def test_calls_setUp_test_tearDown_in_sequence(self):
+        # setUp, the test method and tearDown can all return
+        # Deferreds. AsynchronousDeferredRunTest will make sure that each of
+        # these is run in turn, only going on to the next stage once the
+        # Deferred from the previous stage has fired.
+        call_log = []
+        a = defer.Deferred()
+        a.addCallback(lambda x: call_log.append('a'))
+        b = defer.Deferred()
+        b.addCallback(lambda x: call_log.append('b'))
+        c = defer.Deferred()
+        c.addCallback(lambda x: call_log.append('c'))
+        class SomeCase(TestCase):
+            def setUp(self):
+                super(SomeCase, self).setUp()
+                call_log.append('setUp')
+                return a
+            def test_success(self):
+                call_log.append('test')
+                return b
+            def tearDown(self):
+                super(SomeCase, self).tearDown()
+                call_log.append('tearDown')
+                return c
+        test = SomeCase('test_success')
+        timeout = self.make_timeout()
+        runner = self.make_runner(test, timeout)
+        result = self.make_result()
+        reactor = self.make_reactor()
+        def fire_a():
+            self.assertThat(call_log, Equals(['setUp']))
+            a.callback(None)
+        def fire_b():
+            self.assertThat(call_log, Equals(['setUp', 'a', 'test']))
+            b.callback(None)
+        def fire_c():
+            self.assertThat(
+                call_log, Equals(['setUp', 'a', 'test', 'b', 'tearDown']))
+            c.callback(None)
+        reactor.callLater(timeout * 0.25, fire_a)
+        reactor.callLater(timeout * 0.5, fire_b)
+        reactor.callLater(timeout * 0.75, fire_c)
+        runner.run(result)
+        self.assertThat(
+            call_log, Equals(['setUp', 'a', 'test', 'b', 'tearDown', 'c']))
+
+    def test_async_cleanups(self):
+        # Cleanups added with addCleanup can return
+        # Deferreds. AsynchronousDeferredRunTest will run each of them in
+        # turn.
+        class SomeCase(TestCase):
+            def test_whatever(self):
+                pass
+        test = SomeCase('test_whatever')
+        call_log = []
+        a = defer.Deferred().addCallback(lambda x: call_log.append('a'))
+        b = defer.Deferred().addCallback(lambda x: call_log.append('b'))
+        c = defer.Deferred().addCallback(lambda x: call_log.append('c'))
+        test.addCleanup(lambda: a)
+        test.addCleanup(lambda: b)
+        test.addCleanup(lambda: c)
+        def fire_a():
+            self.assertThat(call_log, Equals([]))
+            a.callback(None)
+        def fire_b():
+            self.assertThat(call_log, Equals(['a']))
+            b.callback(None)
+        def fire_c():
+            self.assertThat(call_log, Equals(['a', 'b']))
+            c.callback(None)
+        timeout = self.make_timeout()
+        reactor = self.make_reactor()
+        reactor.callLater(timeout * 0.25, fire_a)
+        reactor.callLater(timeout * 0.5, fire_b)
+        reactor.callLater(timeout * 0.75, fire_c)
+        runner = self.make_runner(test, timeout)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(call_log, Equals(['a', 'b', 'c']))
+
+    def test_clean_reactor(self):
+        # If there's cruft left over in the reactor, the test fails.
+        reactor = self.make_reactor()
+        timeout = self.make_timeout()
+        class SomeCase(TestCase):
+            def test_cruft(self):
+                reactor.callLater(timeout * 10.0, lambda: None)
+        test = SomeCase('test_cruft')
+        runner = self.make_runner(test, timeout)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events],
+            Equals(
+                [('startTest', test),
+                 ('addError', test),
+                 ('stopTest', test)]))
+        error = result._events[1][2]
+        self.assertThat(error, KeysEqual('traceback', 'twisted-log'))
+
+    def test_unhandled_error_from_deferred(self):
+        # If there's a Deferred with an unhandled error, the test fails.  Each
+        # unhandled error is reported with a separate traceback.
+        class SomeCase(TestCase):
+            def test_cruft(self):
+                # Note we aren't returning the Deferred so that the error will
+                # be unhandled.
+                defer.maybeDeferred(lambda: 1/0)
+                defer.maybeDeferred(lambda: 2/0)
+        test = SomeCase('test_cruft')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        error = result._events[1][2]
+        result._events[1] = ('addError', test, None)
+        self.assertThat(result._events, Equals(
+            [('startTest', test),
+             ('addError', test, None),
+             ('stopTest', test)]))
+        self.assertThat(
+            error, KeysEqual(
+                'twisted-log',
+                'unhandled-error-in-deferred',
+                'unhandled-error-in-deferred-1',
+                ))
+
+    def test_unhandled_error_from_deferred_combined_with_error(self):
+        # If there's a Deferred with an unhandled error, the test fails.  Each
+        # unhandled error is reported with a separate traceback, and the error
+        # is still reported.
+        class SomeCase(TestCase):
+            def test_cruft(self):
+                # Note we aren't returning the Deferred so that the error will
+                # be unhandled.
+                defer.maybeDeferred(lambda: 1/0)
+                2 / 0
+        test = SomeCase('test_cruft')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        error = result._events[1][2]
+        result._events[1] = ('addError', test, None)
+        self.assertThat(result._events, Equals(
+            [('startTest', test),
+             ('addError', test, None),
+             ('stopTest', test)]))
+        self.assertThat(
+            error, KeysEqual(
+                'traceback',
+                'twisted-log',
+                'unhandled-error-in-deferred',
+                ))
+
+    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
+    def test_keyboard_interrupt_stops_test_run(self):
+        # If we get a SIGINT during a test run, the test stops and no more
+        # tests run.
+        SIGINT = getattr(signal, 'SIGINT', None)
+        if not SIGINT:
+            raise self.skipTest("SIGINT unavailable")
+        class SomeCase(TestCase):
+            def test_pause(self):
+                return defer.Deferred()
+        test = SomeCase('test_pause')
+        reactor = self.make_reactor()
+        timeout = self.make_timeout()
+        runner = self.make_runner(test, timeout * 5)
+        result = self.make_result()
+        reactor.callLater(timeout, os.kill, os.getpid(), SIGINT)
+        self.assertThat(lambda:runner.run(result),
+            Raises(MatchesException(KeyboardInterrupt)))
+
+    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
+    def test_fast_keyboard_interrupt_stops_test_run(self):
+        # If we get a SIGINT as soon as the reactor starts running, the test
+        # stops and no more tests run.
+        SIGINT = getattr(signal, 'SIGINT', None)
+        if not SIGINT:
+            raise self.skipTest("SIGINT unavailable")
+        class SomeCase(TestCase):
+            def test_pause(self):
+                return defer.Deferred()
+        test = SomeCase('test_pause')
+        reactor = self.make_reactor()
+        timeout = self.make_timeout()
+        runner = self.make_runner(test, timeout * 5)
+        result = self.make_result()
+        reactor.callWhenRunning(os.kill, os.getpid(), SIGINT)
+        self.assertThat(lambda:runner.run(result),
+            Raises(MatchesException(KeyboardInterrupt)))
+
+    def test_timeout_causes_test_error(self):
+        # If a test times out, it is reported as an error with a
+        # TimeoutError.
+        class SomeCase(TestCase):
+            def test_pause(self):
+                return defer.Deferred()
+        test = SomeCase('test_pause')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        error = result._events[1][2]
+        self.assertThat(
+            [event[:2] for event in result._events], Equals(
+            [('startTest', test),
+             ('addError', test),
+             ('stopTest', test)]))
+        self.assertIn('TimeoutError', str(error['traceback']))
+
+    def test_convenient_construction(self):
+        # As a convenience, AsynchronousDeferredRunTest has a
+        # classmethod that returns an AsynchronousDeferredRunTest
+        # factory. This factory has the same API as the RunTest constructor.
+        reactor = object()
+        timeout = object()
+        handler = object()
+        factory = AsynchronousDeferredRunTest.make_factory(reactor, timeout)
+        runner = factory(self, [handler])
+        self.assertIs(reactor, runner._reactor)
+        self.assertIs(timeout, runner._timeout)
+        self.assertIs(self, runner.case)
+        self.assertEqual([handler], runner.handlers)
+
+    def test_use_convenient_factory(self):
+        # Make sure that the factory can actually be used.
+        factory = AsynchronousDeferredRunTest.make_factory()
+        class SomeCase(TestCase):
+            run_tests_with = factory
+            def test_something(self):
+                pass
+        case = SomeCase('test_something')
+        case.run()
+
+    def test_convenient_construction_default_reactor(self):
+        # make_factory can be given only a reactor; the other constructor
+        # arguments then take their default values.
+        reactor = object()
+        handler = object()
+        factory = AsynchronousDeferredRunTest.make_factory(reactor=reactor)
+        runner = factory(self, [handler])
+        self.assertIs(reactor, runner._reactor)
+        self.assertIs(self, runner.case)
+        self.assertEqual([handler], runner.handlers)
+
+    def test_convenient_construction_default_timeout(self):
+        # make_factory can be given only a timeout; the other constructor
+        # arguments then take their default values.
+        timeout = object()
+        handler = object()
+        factory = AsynchronousDeferredRunTest.make_factory(timeout=timeout)
+        runner = factory(self, [handler])
+        self.assertIs(timeout, runner._timeout)
+        self.assertIs(self, runner.case)
+        self.assertEqual([handler], runner.handlers)
+
+    def test_convenient_construction_default_debugging(self):
+        # make_factory can be given only debug=True; the other constructor
+        # arguments then take their default values.
+        handler = object()
+        factory = AsynchronousDeferredRunTest.make_factory(debug=True)
+        runner = factory(self, [handler])
+        self.assertIs(self, runner.case)
+        self.assertEqual([handler], runner.handlers)
+        self.assertEqual(True, runner._debug)
+
+    def test_deferred_error(self):
+        class SomeTest(TestCase):
+            def test_something(self):
+                return defer.maybeDeferred(lambda: 1/0)
+        test = SomeTest('test_something')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events],
+            Equals([
+                ('startTest', test),
+                ('addError', test),
+                ('stopTest', test)]))
+        error = result._events[1][2]
+        self.assertThat(error, KeysEqual('traceback', 'twisted-log'))
+
+    def test_only_addError_once(self):
+        # Even if the reactor is unclean and the test raises an error and the
+        # cleanups raise errors, we only call addError once per test.
+        reactor = self.make_reactor()
+        class WhenItRains(TestCase):
+            def it_pours(self):
+                # Add a dirty cleanup.
+                self.addCleanup(lambda: 3 / 0)
+                # Dirty the reactor.
+                from twisted.internet.protocol import ServerFactory
+                reactor.listenTCP(0, ServerFactory())
+                # Unhandled error.
+                defer.maybeDeferred(lambda: 2 / 0)
+                # Actual error.
+                raise RuntimeError("Excess precipitation")
+        test = WhenItRains('it_pours')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events],
+            Equals([
+                ('startTest', test),
+                ('addError', test),
+                ('stopTest', test)]))
+        error = result._events[1][2]
+        self.assertThat(
+            error, KeysEqual(
+                'traceback',
+                'traceback-1',
+                'traceback-2',
+                'twisted-log',
+                'unhandled-error-in-deferred',
+                ))
+
+    def test_log_err_is_error(self):
+        # An error logged with log.err during the test run is reported as an
+        # error for the test.
+        class LogAnError(TestCase):
+            def test_something(self):
+                try:
+                    1/0
+                except ZeroDivisionError:
+                    f = failure.Failure()
+                log.err(f)
+        test = LogAnError('test_something')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events],
+            Equals([
+                ('startTest', test),
+                ('addError', test),
+                ('stopTest', test)]))
+        error = result._events[1][2]
+        self.assertThat(error, KeysEqual('logged-error', 'twisted-log'))
+
+    def test_log_err_flushed_is_success(self):
+        # An error logged during the test run and then flushed with
+        # flush_logged_errors is not reported, and the test passes.
+        class LogAnError(TestCase):
+            def test_something(self):
+                try:
+                    1/0
+                except ZeroDivisionError:
+                    f = failure.Failure()
+                log.err(f)
+                flush_logged_errors(ZeroDivisionError)
+        test = LogAnError('test_something')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            result._events,
+            Equals([
+                ('startTest', test),
+                ('addSuccess', test, {'twisted-log': text_content('')}),
+                ('stopTest', test)]))
+
+    def test_log_in_details(self):
+        class LogAnError(TestCase):
+            def test_something(self):
+                log.msg("foo")
+                1/0
+        test = LogAnError('test_something')
+        runner = self.make_runner(test)
+        result = self.make_result()
+        runner.run(result)
+        self.assertThat(
+            [event[:2] for event in result._events],
+            Equals([
+                ('startTest', test),
+                ('addError', test),
+                ('stopTest', test)]))
+        error = result._events[1][2]
+        self.assertThat(error, KeysEqual('traceback', 'twisted-log'))
+
+    def test_debugging_unchanged_during_test_by_default(self):
+        debugging = [(defer.Deferred.debug, DelayedCall.debug)]
+        class SomeCase(TestCase):
+            def test_debugging_enabled(self):
+                debugging.append((defer.Deferred.debug, DelayedCall.debug))
+        test = SomeCase('test_debugging_enabled')
+        runner = AsynchronousDeferredRunTest(
+            test, handlers=test.exception_handlers,
+            reactor=self.make_reactor(), timeout=self.make_timeout())
+        runner.run(self.make_result())
+        self.assertEqual(debugging[0], debugging[1])
+
+    def test_debugging_enabled_during_test_with_debug_flag(self):
+        self.patch(defer.Deferred, 'debug', False)
+        self.patch(DelayedCall, 'debug', False)
+        debugging = []
+        class SomeCase(TestCase):
+            def test_debugging_enabled(self):
+                debugging.append((defer.Deferred.debug, DelayedCall.debug))
+        test = SomeCase('test_debugging_enabled')
+        runner = AsynchronousDeferredRunTest(
+            test, handlers=test.exception_handlers,
+            reactor=self.make_reactor(), timeout=self.make_timeout(),
+            debug=True)
+        runner.run(self.make_result())
+        self.assertEqual([(True, True)], debugging)
+        self.assertEqual(False, defer.Deferred.debug)
+        self.assertEqual(False, DelayedCall.debug)
+
+
+class TestAssertFailsWith(NeedsTwistedTestCase):
+    """Tests for `assert_fails_with`."""
+
+    if SynchronousDeferredRunTest is not None:
+        run_tests_with = SynchronousDeferredRunTest
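+
+    # A rough usage sketch, inferred from the tests below rather than from
+    # separate documentation: a test method can
+    #     return assert_fails_with(some_deferred, SomeError)
+    # and the returned Deferred fires with the failure's value if
+    # some_deferred errbacks with SomeError, or otherwise fails the test
+    # with failureException.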
+
+    def test_assert_fails_with_success(self):
+        # assert_fails_with fails the test if it's given a Deferred that
+        # succeeds.
+        marker = object()
+        d = assert_fails_with(defer.succeed(marker), RuntimeError)
+        def check_result(failure):
+            failure.trap(self.failureException)
+            self.assertThat(
+                str(failure.value),
+                Equals("RuntimeError not raised (%r returned)" % (marker,)))
+        d.addCallbacks(
+            lambda x: self.fail("Should not have succeeded"), check_result)
+        return d
+
+    def test_assert_fails_with_success_multiple_types(self):
+        # assert_fails_with fails the test if it's given a Deferred that
+        # succeeds, even when several exception types are expected.
+        marker = object()
+        d = assert_fails_with(
+            defer.succeed(marker), RuntimeError, ZeroDivisionError)
+        def check_result(failure):
+            failure.trap(self.failureException)
+            self.assertThat(
+                str(failure.value),
+                Equals("RuntimeError, ZeroDivisionError not raised "
+                       "(%r returned)" % (marker,)))
+        d.addCallbacks(
+            lambda x: self.fail("Should not have succeeded"), check_result)
+        return d
+
+    def test_assert_fails_with_wrong_exception(self):
+        # assert_fails_with fails the test if the Deferred fails with an
+        # exception of a type other than the expected ones.
+        d = assert_fails_with(
+            defer.maybeDeferred(lambda: 1/0), RuntimeError, KeyboardInterrupt)
+        def check_result(failure):
+            failure.trap(self.failureException)
+            lines = str(failure.value).splitlines()
+            self.assertThat(
+                lines[:2],
+                Equals([
+                    ("ZeroDivisionError raised instead of RuntimeError, "
+                     "KeyboardInterrupt:"),
+                    " Traceback (most recent call last):",
+                    ]))
+        d.addCallbacks(
+            lambda x: self.fail("Should not have succeeded"), check_result)
+        return d
+
+    def test_assert_fails_with_expected_exception(self):
+        # assert_fails_with calls back with the value of the failure if it's
+        # one of the expected types of failures.
+        try:
+            1/0
+        except ZeroDivisionError:
+            f = failure.Failure()
+        d = assert_fails_with(defer.fail(f), ZeroDivisionError)
+        return d.addCallback(self.assertThat, Equals(f.value))
+
+    def test_custom_failure_exception(self):
+        # If assert_fails_with is passed a 'failureException' keyword
+        # argument, then it will raise that instead of `AssertionError`.
+        class CustomException(Exception):
+            pass
+        marker = object()
+        d = assert_fails_with(
+            defer.succeed(marker), RuntimeError,
+            failureException=CustomException)
+        def check_result(failure):
+            failure.trap(CustomException)
+            self.assertThat(
+                str(failure.value),
+                Equals("RuntimeError not raised (%r returned)" % (marker,)))
+        return d.addCallbacks(
+            lambda x: self.fail("Should not have succeeded"), check_result)
+
+
+def test_suite():
+    from unittest import TestLoader, TestSuite
+    return TestSuite(
+        [TestLoader().loadTestsFromName(__name__),
+         make_integration_tests()])
diff --git a/lib/testtools/testtools/tests/test_fixturesupport.py b/lib/testtools/testtools/tests/test_fixturesupport.py
new file mode 100644 (file)
index 0000000..ebdd037
--- /dev/null
@@ -0,0 +1,77 @@
+import unittest
+
+from testtools import (
+    TestCase,
+    content,
+    content_type,
+    )
+from testtools.helpers import try_import
+from testtools.tests.helpers import (
+    ExtendedTestResult,
+    )
+
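+# 'fixtures' is an optional dependency: these try_imports return None when it
+# is not installed, and the tests below then skip themselves in setUp.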
+fixtures = try_import('fixtures')
+LoggingFixture = try_import('fixtures.tests.helpers.LoggingFixture')
+
+
+class TestFixtureSupport(TestCase):
+
+    def setUp(self):
+        super(TestFixtureSupport, self).setUp()
+        if fixtures is None or LoggingFixture is None:
+            self.skipTest("Need fixtures")
+
+    def test_useFixture(self):
+        fixture = LoggingFixture()
+        class SimpleTest(TestCase):
+            def test_foo(self):
+                self.useFixture(fixture)
+        result = unittest.TestResult()
+        SimpleTest('test_foo').run(result)
+        self.assertTrue(result.wasSuccessful())
+        self.assertEqual(['setUp', 'cleanUp'], fixture.calls)
+
+    def test_useFixture_cleanups_raise_caught(self):
+        calls = []
+        def raiser(ignored):
+            calls.append('called')
+            raise Exception('foo')
+        fixture = fixtures.FunctionFixture(lambda:None, raiser)
+        class SimpleTest(TestCase):
+            def test_foo(self):
+                self.useFixture(fixture)
+        result = unittest.TestResult()
+        SimpleTest('test_foo').run(result)
+        self.assertFalse(result.wasSuccessful())
+        self.assertEqual(['called'], calls)
+
+    def test_useFixture_details_captured(self):
+        class DetailsFixture(fixtures.Fixture):
+            def setUp(self):
+                fixtures.Fixture.setUp(self)
+                self.addCleanup(delattr, self, 'content')
+                self.content = ['content available until cleanUp']
+                self.addDetail('content',
+                    content.Content(content_type.UTF8_TEXT, self.get_content))
+            def get_content(self):
+                return self.content
+        fixture = DetailsFixture()
+        class SimpleTest(TestCase):
+            def test_foo(self):
+                self.useFixture(fixture)
+                # Add a colliding detail (both should show up)
+                self.addDetail('content',
+                    content.Content(content_type.UTF8_TEXT, lambda:['foo']))
+        result = ExtendedTestResult()
+        SimpleTest('test_foo').run(result)
+        self.assertEqual('addSuccess', result._events[-2][0])
+        details = result._events[-2][2]
+        self.assertEqual(['content', 'content-1'], sorted(details.keys()))
+        self.assertEqual('foo', ''.join(details['content'].iter_text()))
+        self.assertEqual('content available until cleanUp',
+            ''.join(details['content-1'].iter_text()))
+
+
+def test_suite():
+    from unittest import TestLoader
+    return TestLoader().loadTestsFromName(__name__)
diff --git a/lib/testtools/testtools/tests/test_helpers.py b/lib/testtools/testtools/tests/test_helpers.py
new file mode 100644 (file)
index 0000000..f1894a4
--- /dev/null
@@ -0,0 +1,106 @@
+# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
+
+from testtools import TestCase
+from testtools.helpers import (
+    try_import,
+    try_imports,
+    )
+from testtools.matchers import (
+    Equals,
+    Is,
+    )
+
+
+class TestTryImport(TestCase):
+
+    def test_doesnt_exist(self):
+        # try_import('thing', foo) returns foo if 'thing' doesn't exist.
+        marker = object()
+        result = try_import('doesntexist', marker)
+        self.assertThat(result, Is(marker))
+
+    def test_None_is_default_alternative(self):
+        # try_import('thing') returns None if 'thing' doesn't exist.
+        result = try_import('doesntexist')
+        self.assertThat(result, Is(None))
+
+    def test_existing_module(self):
+        # try_import('thing', foo) imports 'thing' and returns it if it's a
+        # module that exists.
+        result = try_import('os', object())
+        import os
+        self.assertThat(result, Is(os))
+
+    def test_existing_submodule(self):
+        # try_import('thing.another', foo) imports 'thing.another' and
+        # returns it if it's a module that exists.
+        result = try_import('os.path', object())
+        import os
+        self.assertThat(result, Is(os.path))
+
+    def test_nonexistent_submodule(self):
+        # try_import('thing.another', foo) imports 'thing' and returns foo if
+        # 'another' doesn't exist.
+        marker = object()
+        result = try_import('os.doesntexist', marker)
+        self.assertThat(result, Is(marker))
+
+    def test_object_from_module(self):
+        # try_import('thing.object') imports 'thing' and returns
+        # 'thing.object' if 'thing' is a module and 'object' is not.
+        result = try_import('os.path.join')
+        import os
+        self.assertThat(result, Is(os.path.join))
+
+
+class TestTryImports(TestCase):
+
+    def test_doesnt_exist(self):
+        # try_imports('thing', foo) returns foo if 'thing' doesn't exist.
+        marker = object()
+        result = try_imports(['doesntexist'], marker)
+        self.assertThat(result, Is(marker))
+
+    def test_fallback(self):
+        result = try_imports(['doesntexist', 'os'])
+        import os
+        self.assertThat(result, Is(os))
+
+    def test_None_is_default_alternative(self):
+        # try_imports raises ImportError if none of the named modules exist.
+        e = self.assertRaises(
+            ImportError, try_imports, ['doesntexist', 'noreally'])
+        self.assertThat(
+            str(e),
+            Equals("Could not import any of: doesntexist, noreally"))
+
+    def test_existing_module(self):
+        # try_imports('thing', foo) imports 'thing' and returns it if it's a
+        # module that exists.
+        result = try_imports(['os'], object())
+        import os
+        self.assertThat(result, Is(os))
+
+    def test_existing_submodule(self):
+        # try_imports('thing.another', foo) imports 'thing.another' and
+        # returns it if it's a module that exists.
+        result = try_imports(['os.path'], object())
+        import os
+        self.assertThat(result, Is(os.path))
+
+    def test_nonexistent_submodule(self):
+        # try_imports('thing.another', foo) imports 'thing' and returns foo if
+        # 'another' doesn't exist.
+        marker = object()
+        result = try_imports(['os.doesntexist'], marker)
+        self.assertThat(result, Is(marker))
+
+    def test_fallback_submodule(self):
+        result = try_imports(['os.doesntexist', 'os.path'])
+        import os
+        self.assertThat(result, Is(os.path))
+
+
+def test_suite():
+    from unittest import TestLoader
+    return TestLoader().loadTestsFromName(__name__)
index 164a6a0c50fbe019440457c18129768711529ec5..9cc2c010efe7d0002698e4da438d02698f71e485 100644 (file)
@@ -1,8 +1,9 @@
-# Copyright (c) 2008 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2008-2010 Jonathan M. Lange. See LICENSE for details.
 
 """Tests for matchers."""
 
 import doctest
+import sys
 
 from testtools import (
     Matcher, # check that Matcher is exposed at the top level for docs.
@@ -12,13 +13,21 @@ from testtools.matchers import (
     Annotate,
     Equals,
     DocTestMatches,
+    DoesNotEndWith,
+    DoesNotStartWith,
+    EndsWith,
+    KeysEqual,
     Is,
     LessThan,
     MatchesAny,
     MatchesAll,
+    MatchesException,
     Mismatch,
     Not,
     NotEquals,
+    Raises,
+    raises,
+    StartsWith,
     )
 
 # Silence pyflakes.
@@ -34,7 +43,8 @@ class TestMismatch(TestCase):
 
     def test_constructor_no_arguments(self):
         mismatch = Mismatch()
-        self.assertRaises(NotImplementedError, mismatch.describe)
+        self.assertThat(mismatch.describe,
+            Raises(MatchesException(NotImplementedError)))
         self.assertEqual({}, mismatch.get_details())
 
 
@@ -152,6 +162,58 @@ class TestLessThanInterface(TestCase, TestMatchersInterface):
     describe_examples = [('4 is >= 4', 4, LessThan(4))]
 
 
+def make_error(type, *args, **kwargs):
+    try:
+        raise type(*args, **kwargs)
+    except type:
+        return sys.exc_info()
+
+
+class TestMatchesExceptionInstanceInterface(TestCase, TestMatchersInterface):
+
+    matches_matcher = MatchesException(ValueError("foo"))
+    error_foo = make_error(ValueError, 'foo')
+    error_bar = make_error(ValueError, 'bar')
+    error_base_foo = make_error(Exception, 'foo')
+    matches_matches = [error_foo]
+    matches_mismatches = [error_bar, error_base_foo]
+
+    str_examples = [
+        ("MatchesException(Exception('foo',))",
+         MatchesException(Exception('foo')))
+        ]
+    describe_examples = [
+        ("<type 'exceptions.Exception'> is not a "
+         "<type 'exceptions.ValueError'>",
+         error_base_foo,
+         MatchesException(ValueError("foo"))),
+        ("ValueError('bar',) has different arguments to ValueError('foo',).",
+         error_bar,
+         MatchesException(ValueError("foo"))),
+        ]
+
+
+class TestMatchesExceptionTypeInterface(TestCase, TestMatchersInterface):
+
+    matches_matcher = MatchesException(ValueError)
+    error_foo = make_error(ValueError, 'foo')
+    error_sub = make_error(UnicodeError, 'bar')
+    error_base_foo = make_error(Exception, 'foo')
+    matches_matches = [error_foo, error_sub]
+    matches_mismatches = [error_base_foo]
+
+    str_examples = [
+        ("MatchesException(<type 'exceptions.Exception'>)",
+         MatchesException(Exception))
+        ]
+    describe_examples = [
+        ("<type 'exceptions.Exception'> is not a "
+         "<type 'exceptions.ValueError'>",
+         error_base_foo,
+         MatchesException(ValueError)),
+        ]
+
+
 class TestNotInterface(TestCase, TestMatchersInterface):
 
     matches_matcher = Not(Equals(1))
@@ -209,6 +271,31 @@ class TestMatchesAllInterface(TestCase, TestMatchersInterface):
                           1, MatchesAll(NotEquals(1), NotEquals(2)))]
 
 
+class TestKeysEqual(TestCase, TestMatchersInterface):
+
+    matches_matcher = KeysEqual('foo', 'bar')
+    matches_matches = [
+        {'foo': 0, 'bar': 1},
+        ]
+    matches_mismatches = [
+        {},
+        {'foo': 0},
+        {'bar': 1},
+        {'foo': 0, 'bar': 1, 'baz': 2},
+        {'a': None, 'b': None, 'c': None},
+        ]
+
+    str_examples = [
+        ("KeysEqual('foo', 'bar')", KeysEqual('foo', 'bar')),
+        ]
+
+    describe_examples = [
+        ("['bar', 'foo'] does not match {'baz': 2, 'foo': 0, 'bar': 1}: "
+         "Keys not equal",
+         {'foo': 0, 'bar': 1, 'baz': 2}, KeysEqual('foo', 'bar')),
+        ]
+
+
 class TestAnnotate(TestCase, TestMatchersInterface):
 
     matches_matcher = Annotate("foo", Equals(1))
@@ -221,6 +308,143 @@ class TestAnnotate(TestCase, TestMatchersInterface):
     describe_examples = [("1 != 2: foo", 2, Annotate('foo', Equals(1)))]
 
 
+class TestRaisesInterface(TestCase, TestMatchersInterface):
+
+    matches_matcher = Raises()
+    def boom():
+        raise Exception('foo')
+    matches_matches = [boom]
+    matches_mismatches = [lambda:None]
+
+    # Tricky to get function objects to render consistently, and the interfaces
+    # helper uses assertEqual rather than (for instance) DocTestMatches.
+    str_examples = []
+
+    describe_examples = []
+
+
+class TestRaisesExceptionMatcherInterface(TestCase, TestMatchersInterface):
+
+    matches_matcher = Raises(
+        exception_matcher=MatchesException(Exception('foo')))
+    def boom_bar():
+        raise Exception('bar')
+    def boom_foo():
+        raise Exception('foo')
+    matches_matches = [boom_foo]
+    matches_mismatches = [lambda:None, boom_bar]
+
+    # Tricky to get function objects to render consistently, and the interfaces
+    # helper uses assertEqual rather than (for instance) DocTestMatches.
+    str_examples = []
+
+    describe_examples = []
+
+
+class TestRaisesBaseTypes(TestCase):
+
+    def raiser(self):
+        raise KeyboardInterrupt('foo')
+
+    def test_KeyboardInterrupt_matched(self):
+        # When KeyboardInterrupt is matched, it is swallowed.
+        matcher = Raises(MatchesException(KeyboardInterrupt))
+        self.assertThat(self.raiser, matcher)
+
+    def test_KeyboardInterrupt_propogates(self):
+        # The default 'it raised' matcher propagates KeyboardInterrupt.
+        match_keyb = Raises(MatchesException(KeyboardInterrupt))
+        def raise_keyb_from_match():
+            matcher = Raises()
+            matcher.match(self.raiser)
+        self.assertThat(raise_keyb_from_match, match_keyb)
+
+    def test_KeyboardInterrupt_match_Exception_propogates(self):
+        # If the raised exception isn't matched, and it is not a subclass of
+        # Exception, it is propagated.
+        match_keyb = Raises(MatchesException(KeyboardInterrupt))
+        def raise_keyb_from_match():
+            matcher = Raises(MatchesException(Exception))
+            matcher.match(self.raiser)
+        self.assertThat(raise_keyb_from_match, match_keyb)
+
+
+class TestRaisesConvenience(TestCase):
+
+    def test_exc_type(self):
+        self.assertThat(lambda: 1/0, raises(ZeroDivisionError))
+
+    def test_exc_value(self):
+        e = RuntimeError("You lose!")
+        def raiser():
+            raise e
+        self.assertThat(raiser, raises(e))
+
+
+class DoesNotStartWithTests(TestCase):
+
+    def test_describe(self):
+        mismatch = DoesNotStartWith("fo", "bo")
+        self.assertEqual("'fo' does not start with 'bo'.", mismatch.describe())
+
+
+class StartsWithTests(TestCase):
+
+    def test_str(self):
+        matcher = StartsWith("bar")
+        self.assertEqual("Starts with 'bar'.", str(matcher))
+
+    def test_match(self):
+        matcher = StartsWith("bar")
+        self.assertIs(None, matcher.match("barf"))
+
+    def test_mismatch_returns_does_not_start_with(self):
+        matcher = StartsWith("bar")
+        self.assertIsInstance(matcher.match("foo"), DoesNotStartWith)
+
+    def test_mismatch_sets_matchee(self):
+        matcher = StartsWith("bar")
+        mismatch = matcher.match("foo")
+        self.assertEqual("foo", mismatch.matchee)
+
+    def test_mismatch_sets_expected(self):
+        matcher = StartsWith("bar")
+        mismatch = matcher.match("foo")
+        self.assertEqual("bar", mismatch.expected)
+
+
+class DoesNotEndWithTests(TestCase):
+
+    def test_describe(self):
+        mismatch = DoesNotEndWith("fo", "bo")
+        self.assertEqual("'fo' does not end with 'bo'.", mismatch.describe())
+
+
+class EndsWithTests(TestCase):
+
+    def test_str(self):
+        matcher = EndsWith("bar")
+        self.assertEqual("Ends with 'bar'.", str(matcher))
+
+    def test_match(self):
+        matcher = EndsWith("arf")
+        self.assertIs(None, matcher.match("barf"))
+
+    def test_mismatch_returns_does_not_end_with(self):
+        matcher = EndsWith("bar")
+        self.assertIsInstance(matcher.match("foo"), DoesNotEndWith)
+
+    def test_mismatch_sets_matchee(self):
+        matcher = EndsWith("bar")
+        mismatch = matcher.match("foo")
+        self.assertEqual("foo", mismatch.matchee)
+
+    def test_mismatch_sets_expected(self):
+        matcher = EndsWith("bar")
+        mismatch = matcher.match("foo")
+        self.assertEqual("bar", mismatch.expected)
+
+
 def test_suite():
     from unittest import TestLoader
     return TestLoader().loadTestsFromName(__name__)
index 09388b22f1e9fe582a1e68b0653ba0939c98ddd9..540a2ee909fe5cfce2601ec15cb17467b36e144f 100644 (file)
@@ -4,6 +4,7 @@
 """Tests for testtools.monkey."""
 
 from testtools import TestCase
+from testtools.matchers import MatchesException, Raises
 from testtools.monkey import MonkeyPatcher, patch
 
 
@@ -132,13 +133,13 @@ class MonkeyPatcherTest(TestCase):
         def _():
             self.assertEquals(self.test_object.foo, 'haha')
             self.assertEquals(self.test_object.bar, 'blahblah')
-            raise RuntimeError, "Something went wrong!"
+            raise RuntimeError("Something went wrong!")
 
         self.monkey_patcher.add_patch(self.test_object, 'foo', 'haha')
         self.monkey_patcher.add_patch(self.test_object, 'bar', 'blahblah')
 
-        self.assertRaises(
-            RuntimeError, self.monkey_patcher.run_with_patches, _)
+        self.assertThat(lambda:self.monkey_patcher.run_with_patches(_),
+            Raises(MatchesException(RuntimeError("Something went wrong!"))))
         self.assertEquals(self.test_object.foo, self.original_object.foo)
         self.assertEquals(self.test_object.bar, self.original_object.bar)
 
diff --git a/lib/testtools/testtools/tests/test_run.py b/lib/testtools/testtools/tests/test_run.py
new file mode 100644 (file)
index 0000000..5087527
--- /dev/null
@@ -0,0 +1,77 @@
+# Copyright (c) 2010 Testtools authors. See LICENSE for details.
+
+"""Tests for the test runner logic."""
+
+import StringIO
+
+from testtools.helpers import try_import
+fixtures = try_import('fixtures')
+
+import testtools
+from testtools import TestCase, run
+
+
+if fixtures:
+    class SampleTestFixture(fixtures.Fixture):
+        """Creates testtools.runexample temporarily."""
+
+        def __init__(self):
+            self.package = fixtures.PythonPackage(
+            'runexample', [('__init__.py', """
+from testtools import TestCase
+
+class TestFoo(TestCase):
+    def test_bar(self):
+        pass
+    def test_quux(self):
+        pass
+def test_suite():
+    from unittest import TestLoader
+    return TestLoader().loadTestsFromName(__name__)
+""")])
+
+        def setUp(self):
+            super(SampleTestFixture, self).setUp()
+            self.useFixture(self.package)
+            testtools.__path__.append(self.package.base)
+            self.addCleanup(testtools.__path__.remove, self.package.base)
+
+
+class TestRun(TestCase):
+
+    def test_run_list(self):
+        if fixtures is None:
+            self.skipTest("Need fixtures")
+        package = self.useFixture(SampleTestFixture())
+        out = StringIO.StringIO()
+        run.main(['prog', '-l', 'testtools.runexample.test_suite'], out)
+        self.assertEqual("""testtools.runexample.TestFoo.test_bar
+testtools.runexample.TestFoo.test_quux
+""", out.getvalue())
+
+    def test_run_load_list(self):
+        if fixtures is None:
+            self.skipTest("Need fixtures")
+        package = self.useFixture(SampleTestFixture())
+        out = StringIO.StringIO()
+        # The load list names two tests - one that exists and one that
+        # doesn't. Only the existing, listed test should run; neither the
+        # missing test nor the unlisted test_quux should.
+        tempdir = self.useFixture(fixtures.TempDir())
+        tempname = tempdir.path + '/tests.list'
+        f = open(tempname, 'wb')
+        try:
+            f.write("""
+testtools.runexample.TestFoo.test_bar
+testtools.runexample.missingtest
+""")
+        finally:
+            f.close()
+        run.main(['prog', '-l', '--load-list', tempname,
+            'testtools.runexample.test_suite'], out)
+        self.assertEqual("""testtools.runexample.TestFoo.test_bar
+""", out.getvalue())
+
+
+def test_suite():
+    from unittest import TestLoader
+    return TestLoader().loadTestsFromName(__name__)
index a4c0a728b1f4a8ccf0bce0f85706ef86e52fed7e..02863ac6fd2c8c4abd91ecb1eda0d4a9eda69f7e 100644 (file)
@@ -1,13 +1,15 @@
-# Copyright (c) 2009 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
 
 """Tests for the RunTest single test execution logic."""
 
 from testtools import (
     ExtendedToOriginalDecorator,
+    run_test_with,
     RunTest,
     TestCase,
     TestResult,
     )
+from testtools.matchers import MatchesException, Is, Raises
 from testtools.tests.helpers import ExtendedTestResult
 
 
@@ -62,7 +64,8 @@ class TestRunTest(TestCase):
             raise KeyboardInterrupt("yo")
         run = RunTest(case, None)
         run.result = ExtendedTestResult()
-        self.assertRaises(KeyboardInterrupt, run._run_user, raises)
+        self.assertThat(lambda: run._run_user(raises),
+            Raises(MatchesException(KeyboardInterrupt)))
         self.assertEqual([], run.result._events)
 
     def test__run_user_calls_onException(self):
@@ -107,7 +110,8 @@ class TestRunTest(TestCase):
             log.append((result, err))
         run = RunTest(case, [(ValueError, log_exc)])
         run.result = ExtendedTestResult()
-        self.assertRaises(KeyError, run._run_user, raises)
+        self.assertThat(lambda: run._run_user(raises),
+            Raises(MatchesException(KeyError)))
         self.assertEqual([], run.result._events)
         self.assertEqual([], log)
 
@@ -126,7 +130,8 @@ class TestRunTest(TestCase):
             log.append((result, err))
         run = RunTest(case, [(ValueError, log_exc)])
         run.result = ExtendedTestResult()
-        self.assertRaises(ValueError, run._run_user, raises)
+        self.assertThat(lambda: run._run_user(raises),
+            Raises(MatchesException(ValueError)))
         self.assertEqual([], run.result._events)
         self.assertEqual([], log)
 
@@ -169,13 +174,127 @@ class TestRunTest(TestCase):
             raise Exception("foo")
         run = RunTest(case, lambda x: x)
         run._run_core = inner
-        self.assertRaises(Exception, run.run, result)
+        self.assertThat(lambda: run.run(result),
+            Raises(MatchesException(Exception("foo"))))
         self.assertEqual([
             ('startTest', case),
             ('stopTest', case),
             ], result._events)
 
 
+class CustomRunTest(RunTest):
+
+    marker = object()
+
+    def run(self, result=None):
+        return self.marker
+
+
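+# A note on precedence as exercised by the tests below: run_tests_with (class
+# attribute) and @run_test_with (method decorator) each select the runner,
+# and a runTest= constructor argument overrides both.
+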
+class TestTestCaseSupportForRunTest(TestCase):
+
+    def test_pass_custom_run_test(self):
+        class SomeCase(TestCase):
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo', runTest=CustomRunTest)
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(CustomRunTest.marker))
+
+    def test_default_is_runTest_class_variable(self):
+        class SomeCase(TestCase):
+            run_tests_with = CustomRunTest
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo')
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(CustomRunTest.marker))
+
+    def test_constructor_argument_overrides_class_variable(self):
+        # If a 'runTest' argument is passed to the test's constructor, that
+        # overrides the class variable.
+        marker = object()
+        class DifferentRunTest(RunTest):
+            def run(self, result=None):
+                return marker
+        class SomeCase(TestCase):
+            run_tests_with = CustomRunTest
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo', runTest=DifferentRunTest)
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(marker))
+
+    def test_decorator_for_run_test(self):
+        # Individual test methods can be marked as needing a special runner.
+        class SomeCase(TestCase):
+            @run_test_with(CustomRunTest)
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo')
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(CustomRunTest.marker))
+
+    def test_extended_decorator_for_run_test(self):
+        # Individual test methods can be marked as needing a special runner.
+        # Extra arguments can be passed to the decorator which will then be
+        # passed on to the RunTest object.
+        marker = object()
+        class FooRunTest(RunTest):
+            def __init__(self, case, handlers=None, bar=None):
+                super(FooRunTest, self).__init__(case, handlers)
+                self.bar = bar
+            def run(self, result=None):
+                return self.bar
+        class SomeCase(TestCase):
+            @run_test_with(FooRunTest, bar=marker)
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo')
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(marker))
+
+    def test_works_as_inner_decorator(self):
+        # Even if run_test_with is the innermost decorator, it will be
+        # respected.
+        def wrapped(function):
+            """Silly, trivial decorator."""
+            def decorated(*args, **kwargs):
+                return function(*args, **kwargs)
+            decorated.__name__ = function.__name__
+            decorated.__dict__.update(function.__dict__)
+            return decorated
+        class SomeCase(TestCase):
+            @wrapped
+            @run_test_with(CustomRunTest)
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo')
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(CustomRunTest.marker))
+
+    def test_constructor_overrides_decorator(self):
+        # If a 'runTest' argument is passed to the test's constructor, that
+        # overrides the decorator.
+        marker = object()
+        class DifferentRunTest(RunTest):
+            def run(self, result=None):
+                return marker
+        class SomeCase(TestCase):
+            @run_test_with(CustomRunTest)
+            def test_foo(self):
+                pass
+        result = TestResult()
+        case = SomeCase('test_foo', runTest=DifferentRunTest)
+        from_run_test = case.run(result)
+        self.assertThat(from_run_test, Is(marker))
+
+
 def test_suite():
     from unittest import TestLoader
     return TestLoader().loadTestsFromName(__name__)
diff --git a/lib/testtools/testtools/tests/test_spinner.py b/lib/testtools/testtools/tests/test_spinner.py
new file mode 100644 (file)
index 0000000..f898956
--- /dev/null
@@ -0,0 +1,325 @@
+# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
+
+"""Tests for the evil Twisted reactor-spinning we do."""
+
+import os
+import signal
+
+from testtools import (
+    skipIf,
+    TestCase,
+    )
+from testtools.helpers import try_import
+from testtools.matchers import (
+    Equals,
+    Is,
+    MatchesException,
+    Raises,
+    )
+
+_spinner = try_import('testtools._spinner')
+
+defer = try_import('twisted.internet.defer')
+Failure = try_import('twisted.python.failure.Failure')
+
+
+class NeedsTwistedTestCase(TestCase):
+
+    def setUp(self):
+        super(NeedsTwistedTestCase, self).setUp()
+        if defer is None or Failure is None:
+            self.skipTest("Need Twisted to run")
+
+
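+
+# A rough sketch of the Spinner API exercised below (inferred from these
+# tests, not from separate documentation):
+#     result = _spinner.Spinner(reactor).run(timeout, function, *args, **kw)
+# runs 'function' inside the reactor, waits for any Deferred it returns, and
+# either returns the final value or raises (e.g. _spinner.TimeoutError).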
+class TestNotReentrant(NeedsTwistedTestCase):
+
+    def test_not_reentrant(self):
+        # A function decorated as not being re-entrant will raise a
+        # _spinner.ReentryError if it is called while it is running.
+        calls = []
+        @_spinner.not_reentrant
+        def log_something():
+            calls.append(None)
+            if len(calls) < 5:
+                log_something()
+        self.assertThat(
+            log_something, Raises(MatchesException(_spinner.ReentryError)))
+        self.assertEqual(1, len(calls))
+
+    def test_deeper_stack(self):
+        calls = []
+        @_spinner.not_reentrant
+        def g():
+            calls.append(None)
+            if len(calls) < 5:
+                f()
+        @_spinner.not_reentrant
+        def f():
+            calls.append(None)
+            if len(calls) < 5:
+                g()
+        self.assertThat(f, Raises(MatchesException(_spinner.ReentryError)))
+        self.assertEqual(2, len(calls))
+
+
+class TestExtractResult(NeedsTwistedTestCase):
+
+    def test_not_fired(self):
+        # _spinner.extract_result raises _spinner.DeferredNotFired if it's
+        # given a Deferred that has not fired.
+        self.assertThat(lambda:_spinner.extract_result(defer.Deferred()),
+            Raises(MatchesException(_spinner.DeferredNotFired)))
+
+    def test_success(self):
+        # _spinner.extract_result returns the value of the Deferred if it has
+        # fired successfully.
+        marker = object()
+        d = defer.succeed(marker)
+        self.assertThat(_spinner.extract_result(d), Equals(marker))
+
+    def test_failure(self):
+        # _spinner.extract_result raises the failure's exception if it's given
+        # a Deferred that is failing.
+        try:
+            1/0
+        except ZeroDivisionError:
+            f = Failure()
+        d = defer.fail(f)
+        self.assertThat(lambda:_spinner.extract_result(d),
+            Raises(MatchesException(ZeroDivisionError)))
+
+
+class TestTrapUnhandledErrors(NeedsTwistedTestCase):
+
+    def test_no_deferreds(self):
+        marker = object()
+        result, errors = _spinner.trap_unhandled_errors(lambda: marker)
+        self.assertEqual([], errors)
+        self.assertIs(marker, result)
+
+    def test_unhandled_error(self):
+        failures = []
+        def make_deferred_but_dont_handle():
+            try:
+                1/0
+            except ZeroDivisionError:
+                f = Failure()
+                failures.append(f)
+                defer.fail(f)
+        result, errors = _spinner.trap_unhandled_errors(
+            make_deferred_but_dont_handle)
+        self.assertIs(None, result)
+        self.assertEqual(failures, [error.failResult for error in errors])
+
+
+class TestRunInReactor(NeedsTwistedTestCase):
+
+    def make_reactor(self):
+        from twisted.internet import reactor
+        return reactor
+
+    def make_spinner(self, reactor=None):
+        if reactor is None:
+            reactor = self.make_reactor()
+        return _spinner.Spinner(reactor)
+
+    def make_timeout(self):
+        return 0.01
+
+    def test_function_called(self):
+        # Spinner.run actually calls the function given to it.
+        calls = []
+        marker = object()
+        self.make_spinner().run(self.make_timeout(), calls.append, marker)
+        self.assertThat(calls, Equals([marker]))
+
+    def test_return_value_returned(self):
+        # Spinner.run returns the value returned by the function given to
+        # it.
+        marker = object()
+        result = self.make_spinner().run(self.make_timeout(), lambda: marker)
+        self.assertThat(result, Is(marker))
+
+    def test_exception_reraised(self):
+        # If the given function raises an error, Spinner.run re-raises that
+        # error.
+        self.assertThat(
+            lambda:self.make_spinner().run(self.make_timeout(), lambda: 1/0),
+            Raises(MatchesException(ZeroDivisionError)))
+
+    def test_keyword_arguments(self):
+        # Spinner.run passes keyword arguments on.
+        calls = []
+        function = lambda *a, **kw: calls.extend([a, kw])
+        self.make_spinner().run(self.make_timeout(), function, foo=42)
+        self.assertThat(calls, Equals([(), {'foo': 42}]))
+
+    def test_not_reentrant(self):
+        # Spinner.run raises an error if it is called inside another call
+        # to Spinner.run.
+        spinner = self.make_spinner()
+        self.assertThat(lambda: spinner.run(
+            self.make_timeout(), spinner.run, self.make_timeout(),
+            lambda: None), Raises(MatchesException(_spinner.ReentryError)))
+
+    def test_deferred_value_returned(self):
+        # If the given function returns a Deferred, run_in_reactor returns the
+        # value in the Deferred at the end of the callback chain.
+        marker = object()
+        result = self.make_spinner().run(
+            self.make_timeout(), lambda: defer.succeed(marker))
+        self.assertThat(result, Is(marker))
+
+    def test_preserve_signal_handler(self):
+        signals = ['SIGINT', 'SIGTERM', 'SIGCHLD']
+        signals = filter(
+            None, (getattr(signal, name, None) for name in signals))
+        for sig in signals:
+            self.addCleanup(signal.signal, sig, signal.getsignal(sig))
+        new_hdlrs = list(lambda *a: None for _ in signals)
+        for sig, hdlr in zip(signals, new_hdlrs):
+            signal.signal(sig, hdlr)
+        spinner = self.make_spinner()
+        spinner.run(self.make_timeout(), lambda: None)
+        self.assertEqual(new_hdlrs, map(signal.getsignal, signals))
+
+    def test_timeout(self):
+        # If the function takes too long to run, we raise a
+        # _spinner.TimeoutError.
+        timeout = self.make_timeout()
+        self.assertThat(
+            lambda:self.make_spinner().run(timeout, lambda: defer.Deferred()),
+            Raises(MatchesException(_spinner.TimeoutError)))
+
+    def test_no_junk_by_default(self):
+        # If the reactor hasn't spun yet, then there cannot be any junk.
+        spinner = self.make_spinner()
+        self.assertThat(spinner.get_junk(), Equals([]))
+
+    def test_clean_do_nothing(self):
+        # If there's nothing going on in the reactor, then clean does nothing
+        # and returns an empty list.
+        spinner = self.make_spinner()
+        result = spinner._clean()
+        self.assertThat(result, Equals([]))
+
+    def test_clean_delayed_call(self):
+        # If there's a delayed call in the reactor, then clean cancels it and
+        # returns it in the list of cleaned-up junk.
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        call = reactor.callLater(10, lambda: None)
+        results = spinner._clean()
+        self.assertThat(results, Equals([call]))
+        self.assertThat(call.active(), Equals(False))
+
+    def test_clean_delayed_call_cancelled(self):
+        # If there's a delayed call that's just been cancelled, then it's no
+        # longer there.
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        call = reactor.callLater(10, lambda: None)
+        call.cancel()
+        results = spinner._clean()
+        self.assertThat(results, Equals([]))
+
+    def test_clean_selectables(self):
+        # If there's still a selectable (e.g. a listening socket), then
+        # clean() removes it from the reactor's registry.
+        #
+        # Note that the socket is left open. This emulates a bug in trial.
+        from twisted.internet.protocol import ServerFactory
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        port = reactor.listenTCP(0, ServerFactory())
+        spinner.run(self.make_timeout(), lambda: None)
+        results = spinner.get_junk()
+        self.assertThat(results, Equals([port]))
+
+    def test_clean_running_threads(self):
+        import threading
+        import time
+        current_threads = list(threading.enumerate())
+        reactor = self.make_reactor()
+        timeout = self.make_timeout()
+        spinner = self.make_spinner(reactor)
+        spinner.run(timeout, reactor.callInThread, time.sleep, timeout / 2.0)
+        self.assertThat(list(threading.enumerate()), Equals(current_threads))
+
+    def test_leftover_junk_available(self):
+        # If 'run' is given a function that leaves the reactor dirty in some
+        # way, 'run' will clean up the reactor and then store information
+        # about the junk. That information can be retrieved with get_junk.
+        from twisted.internet.protocol import ServerFactory
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        port = spinner.run(
+            self.make_timeout(), reactor.listenTCP, 0, ServerFactory())
+        self.assertThat(spinner.get_junk(), Equals([port]))
+
+    def test_will_not_run_with_previous_junk(self):
+        # If 'run' is called and there's still junk in the spinner's junk
+        # list, then the spinner will refuse to run.
+        from twisted.internet.protocol import ServerFactory
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        timeout = self.make_timeout()
+        spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
+        self.assertThat(lambda: spinner.run(timeout, lambda: None),
+            Raises(MatchesException(_spinner.StaleJunkError)))
+
+    def test_clear_junk_clears_previous_junk(self):
+        # clear_junk returns the junk left over from the previous run and
+        # resets the spinner's junk list, so get_junk then reports nothing.
+        from twisted.internet.protocol import ServerFactory
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        timeout = self.make_timeout()
+        port = spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
+        junk = spinner.clear_junk()
+        self.assertThat(junk, Equals([port]))
+        self.assertThat(spinner.get_junk(), Equals([]))
+
+    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
+    def test_sigint_raises_no_result_error(self):
+        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
+        SIGINT = getattr(signal, 'SIGINT', None)
+        if not SIGINT:
+            self.skipTest("SIGINT not available")
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        timeout = self.make_timeout()
+        reactor.callLater(timeout, os.kill, os.getpid(), SIGINT)
+        self.assertThat(lambda:spinner.run(timeout * 5, defer.Deferred),
+            Raises(MatchesException(_spinner.NoResultError)))
+        self.assertEqual([], spinner._clean())
+
+    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
+    def test_sigint_raises_no_result_error_second_time(self):
+        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
+        # This test is exactly the same as test_sigint_raises_no_result_error,
+        # and exists to make sure we haven't futzed with state.
+        self.test_sigint_raises_no_result_error()
+
+    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
+    def test_fast_sigint_raises_no_result_error(self):
+        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
+        SIGINT = getattr(signal, 'SIGINT', None)
+        if not SIGINT:
+            self.skipTest("SIGINT not available")
+        reactor = self.make_reactor()
+        spinner = self.make_spinner(reactor)
+        timeout = self.make_timeout()
+        reactor.callWhenRunning(os.kill, os.getpid(), SIGINT)
+        self.assertThat(lambda:spinner.run(timeout * 5, defer.Deferred),
+            Raises(MatchesException(_spinner.NoResultError)))
+        self.assertEqual([], spinner._clean())
+
+    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
+    def test_fast_sigint_raises_no_result_error_second_time(self):
+        self.test_fast_sigint_raises_no_result_error()
+
+
+def test_suite():
+    from unittest import TestLoader
+    return TestLoader().loadTestsFromName(__name__)
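As a rough orientation for the tests above (this sketch is not part of the patch), the Spinner API they exercise is used roughly like this, assuming the default Twisted reactor and the private testtools._spinner module shipped in this snapshot::

    from twisted.internet import defer, reactor
    from testtools import _spinner

    spinner = _spinner.Spinner(reactor)

    # run() calls the function inside the reactor and waits at most `timeout`
    # seconds; if the function returns a Deferred, the value at the end of its
    # callback chain is returned here.
    result = spinner.run(5.0, lambda: defer.succeed(42))
    assert result == 42

    # A Deferred that never fires makes run() raise _spinner.TimeoutError;
    # anything still registered with the reactor afterwards is kept as "junk".
    leftovers = spinner.get_junk()   # leftover delayed calls, sockets, ...
    spinner.clear_junk()             # hand back that junk and reset the list
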
index 1a19440069259c0fbaccf3cea498a12f62b62a6b..a0e090d9210946c2a84da58c97ff535e9355f31e 100644 (file)
@@ -6,10 +6,6 @@ __metaclass__ = type
 
 import codecs
 import datetime
-try:
-    from StringIO import StringIO
-except ImportError:
-    from io import StringIO
 import doctest
 import os
 import shutil
@@ -26,6 +22,7 @@ from testtools import (
     TextTestResult,
     ThreadsafeForwardingResult,
     testresult,
+    try_imports,
     )
 from testtools.compat import (
     _b,
@@ -34,8 +31,13 @@ from testtools.compat import (
     _u,
     str_is_unicode,
     )
-from testtools.content import Content, ContentType
-from testtools.matchers import DocTestMatches
+from testtools.content import Content
+from testtools.content_type import ContentType, UTF8_TEXT
+from testtools.matchers import (
+    DocTestMatches,
+    MatchesException,
+    Raises,
+    )
 from testtools.tests.helpers import (
     LoggingResult,
     Python26TestResult,
@@ -44,81 +46,198 @@ from testtools.tests.helpers import (
     an_exc_info
     )
 
+StringIO = try_imports(['StringIO.StringIO', 'io.StringIO'])
 
-class TestTestResultContract(TestCase):
-    """Tests for the contract of TestResults."""
+
+class Python26Contract(object):
+
+    def test_fresh_result_is_successful(self):
+        # A result is considered successful before any tests are run.
+        result = self.makeResult()
+        self.assertTrue(result.wasSuccessful())
+
+    def test_addError_is_failure(self):
+        # addError fails the test run.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addError(self, an_exc_info)
+        result.stopTest(self)
+        self.assertFalse(result.wasSuccessful())
+
+    def test_addFailure_is_failure(self):
+        # addFailure fails the test run.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addFailure(self, an_exc_info)
+        result.stopTest(self)
+        self.assertFalse(result.wasSuccessful())
+
+    def test_addSuccess_is_success(self):
+        # addSuccess does not fail the test run.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addSuccess(self)
+        result.stopTest(self)
+        self.assertTrue(result.wasSuccessful())
+
+
+class Python27Contract(Python26Contract):
 
     def test_addExpectedFailure(self):
         # Calling addExpectedFailure(test, exc_info) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addExpectedFailure(self, an_exc_info)
 
+    def test_addExpectedFailure_is_success(self):
+        # addExpectedFailure does not fail the test run.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addExpectedFailure(self, an_exc_info)
+        result.stopTest(self)
+        self.assertTrue(result.wasSuccessful())
+
+    def test_addSkipped(self):
+        # Calling addSkip(test, reason) completes ok.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addSkip(self, _u("Skipped for some reason"))
+
+    def test_addSkip_is_success(self):
+        # addSkip does not fail the test run.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addSkip(self, _u("Skipped for some reason"))
+        result.stopTest(self)
+        self.assertTrue(result.wasSuccessful())
+
+    def test_addUnexpectedSuccess(self):
+        # Calling addUnexpectedSuccess(test) completes ok.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addUnexpectedSuccess(self)
+
+    def test_addUnexpectedSuccess_was_successful(self):
+        # addUnexpectedSuccess does not fail the test run in Python 2.7.
+        result = self.makeResult()
+        result.startTest(self)
+        result.addUnexpectedSuccess(self)
+        result.stopTest(self)
+        self.assertTrue(result.wasSuccessful())
+
+    def test_startStopTestRun(self):
+        # Calling startTestRun completes ok.
+        result = self.makeResult()
+        result.startTestRun()
+        result.stopTestRun()
+
+
+class DetailsContract(Python27Contract):
+    """Tests for the contract of TestResults."""
+
     def test_addExpectedFailure_details(self):
         # Calling addExpectedFailure(test, details=xxx) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addExpectedFailure(self, details={})
 
     def test_addError_details(self):
         # Calling addError(test, details=xxx) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addError(self, details={})
 
     def test_addFailure_details(self):
         # Calling addFailure(test, details=xxx) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addFailure(self, details={})
 
-    def test_addSkipped(self):
-        # Calling addSkip(test, reason) completes ok.
-        result = self.makeResult()
-        result.addSkip(self, _u("Skipped for some reason"))
-
     def test_addSkipped_details(self):
         # Calling addSkip(test, reason) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addSkip(self, details={})
 
-    def test_addUnexpectedSuccess(self):
-        # Calling addUnexpectedSuccess(test) completes ok.
-        result = self.makeResult()
-        result.addUnexpectedSuccess(self)
-
     def test_addUnexpectedSuccess_details(self):
         # Calling addUnexpectedSuccess(test) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addUnexpectedSuccess(self, details={})
 
     def test_addSuccess_details(self):
         # Calling addSuccess(test) completes ok.
         result = self.makeResult()
+        result.startTest(self)
         result.addSuccess(self, details={})
 
-    def test_startStopTestRun(self):
-        # Calling startTestRun completes ok.
+
+class FallbackContract(DetailsContract):
+    """When we fallback we take our policy choice to map calls.
+
+    For instance, we map unexpectedSuccess to an error code, not to success.
+    """
+
+    def test_addUnexpectedSuccess_was_successful(self):
+        # addUnexpectedSuccess fails the test run in testtools.
         result = self.makeResult()
+        result.startTest(self)
+        result.addUnexpectedSuccess(self)
+        result.stopTest(self)
+        self.assertFalse(result.wasSuccessful())
+
+
+class StartTestRunContract(FallbackContract):
+    """Defines the contract for testtools policy choices.
+
+    That is, things which are not simply extensions to unittest but choices
+    we have made differently.
+    """
+
+    def test_startTestRun_resets_unexpected_success(self):
+        result = self.makeResult()
+        result.startTest(self)
+        result.addUnexpectedSuccess(self)
+        result.stopTest(self)
         result.startTestRun()
-        result.stopTestRun()
+        self.assertTrue(result.wasSuccessful())
 
+    def test_startTestRun_resets_failure(self):
+        result = self.makeResult()
+        result.startTest(self)
+        result.addFailure(self, an_exc_info)
+        result.stopTest(self)
+        result.startTestRun()
+        self.assertTrue(result.wasSuccessful())
 
-class TestTestResultContract(TestTestResultContract):
+    def test_startTestRun_resets_errors(self):
+        result = self.makeResult()
+        result.startTest(self)
+        result.addError(self, an_exc_info)
+        result.stopTest(self)
+        result.startTestRun()
+        self.assertTrue(result.wasSuccessful())
+
+
+class TestTestResultContract(TestCase, StartTestRunContract):
 
     def makeResult(self):
         return TestResult()
 
 
-class TestMultiTestresultContract(TestTestResultContract):
+class TestMultiTestResultContract(TestCase, StartTestRunContract):
 
     def makeResult(self):
         return MultiTestResult(TestResult(), TestResult())
 
 
-class TestTextTestResultContract(TestTestResultContract):
+class TestTextTestResultContract(TestCase, StartTestRunContract):
 
     def makeResult(self):
         return TextTestResult(StringIO())
 
 
-class TestThreadSafeForwardingResultContract(TestTestResultContract):
+class TestThreadSafeForwardingResultContract(TestCase, StartTestRunContract):
 
     def makeResult(self):
         result_semaphore = threading.Semaphore(1)
@@ -126,6 +245,36 @@ class TestThreadSafeForwardingResultContract(TestTestResultContract):
         return ThreadsafeForwardingResult(target, result_semaphore)
 
 
+class TestExtendedTestResultContract(TestCase, StartTestRunContract):
+
+    def makeResult(self):
+        return ExtendedTestResult()
+
+
+class TestPython26TestResultContract(TestCase, Python26Contract):
+
+    def makeResult(self):
+        return Python26TestResult()
+
+
+class TestAdaptedPython26TestResultContract(TestCase, FallbackContract):
+
+    def makeResult(self):
+        return ExtendedToOriginalDecorator(Python26TestResult())
+
+
+class TestPython27TestResultContract(TestCase, Python27Contract):
+
+    def makeResult(self):
+        return Python27TestResult()
+
+
+class TestAdaptedPython27TestResultContract(TestCase, DetailsContract):
+
+    def makeResult(self):
+        return ExtendedToOriginalDecorator(Python27TestResult())
+
+
 class TestTestResult(TestCase):
     """Tests for `TestResult`."""
 
@@ -295,6 +444,12 @@ class TestTextTestResult(TestCase):
                 self.fail("yo!")
         return Test("failed")
 
+    def make_unexpectedly_successful_test(self):
+        class Test(TestCase):
+            def succeeded(self):
+                self.expectFailure("yo!", lambda: None)
+        return Test("succeeded")
+
     def make_test(self):
         class Test(TestCase):
             def test(self):
@@ -380,9 +535,18 @@ class TestTextTestResult(TestCase):
         self.assertThat(self.getvalue(),
             DocTestMatches("...\n\nFAILED (failures=1)\n", doctest.ELLIPSIS))
 
+    def test_stopTestRun_not_successful_unexpected_success(self):
+        test = self.make_unexpectedly_successful_test()
+        self.result.startTestRun()
+        test.run(self.result)
+        self.result.stopTestRun()
+        self.assertThat(self.getvalue(),
+            DocTestMatches("...\n\nFAILED (failures=1)\n", doctest.ELLIPSIS))
+
     def test_stopTestRun_shows_details(self):
         self.result.startTestRun()
         self.make_erroring_test().run(self.result)
+        self.make_unexpectedly_successful_test().run(self.result)
         self.make_failing_test().run(self.result)
         self.reset_output()
         self.result.stopTestRun()
@@ -394,9 +558,9 @@ Text attachment: traceback
 ------------
 Traceback (most recent call last):
   File "...testtools...runtest.py", line ..., in _run_user...
-    return fn(*args)
+    return fn(*args, **kwargs)
   File "...testtools...testcase.py", line ..., in _run_test_method
-    testMethod()
+    return self._get_test_method()()
   File "...testtools...tests...test_testresult.py", line ..., in error
     1/0
 ZeroDivisionError:... divi... by zero...
@@ -408,18 +572,21 @@ Text attachment: traceback
 ------------
 Traceback (most recent call last):
   File "...testtools...runtest.py", line ..., in _run_user...
-    return fn(*args)
+    return fn(*args, **kwargs)
   File "...testtools...testcase.py", line ..., in _run_test_method
-    testMethod()
+    return self._get_test_method()()
   File "...testtools...tests...test_testresult.py", line ..., in failed
     self.fail("yo!")
 AssertionError: yo!
 ------------
-...""", doctest.ELLIPSIS))
+======================================================================
+UNEXPECTED SUCCESS: testtools.tests.test_testresult.Test.succeeded
+----------------------------------------------------------------------
+...""", doctest.ELLIPSIS | doctest.REPORT_NDIFF))
 
 
 class TestThreadSafeForwardingResult(TestWithFakeExceptions):
-    """Tests for `MultiTestResult`."""
+    """Tests for `TestThreadSafeForwardingResult`."""
 
     def setUp(self):
         TestWithFakeExceptions.setUp(self)
@@ -452,22 +619,51 @@ class TestThreadSafeForwardingResult(TestWithFakeExceptions):
     def test_forwarding_methods(self):
         # error, failure, skip and success are forwarded in batches.
         exc_info1 = self.makeExceptionInfo(RuntimeError, 'error')
+        starttime1 = datetime.datetime.utcfromtimestamp(1.489)
+        endtime1 = datetime.datetime.utcfromtimestamp(51.476)
+        self.result1.time(starttime1)
+        self.result1.startTest(self)
+        self.result1.time(endtime1)
         self.result1.addError(self, exc_info1)
         exc_info2 = self.makeExceptionInfo(AssertionError, 'failure')
+        starttime2 = datetime.datetime.utcfromtimestamp(2.489)
+        endtime2 = datetime.datetime.utcfromtimestamp(3.476)
+        self.result1.time(starttime2)
+        self.result1.startTest(self)
+        self.result1.time(endtime2)
         self.result1.addFailure(self, exc_info2)
         reason = _u("Skipped for some reason")
+        starttime3 = datetime.datetime.utcfromtimestamp(4.489)
+        endtime3 = datetime.datetime.utcfromtimestamp(5.476)
+        self.result1.time(starttime3)
+        self.result1.startTest(self)
+        self.result1.time(endtime3)
         self.result1.addSkip(self, reason)
+        starttime4 = datetime.datetime.utcfromtimestamp(6.489)
+        endtime4 = datetime.datetime.utcfromtimestamp(7.476)
+        self.result1.time(starttime4)
+        self.result1.startTest(self)
+        self.result1.time(endtime4)
         self.result1.addSuccess(self)
-        self.assertEqual([('startTest', self),
+        self.assertEqual([
+            ('time', starttime1),
+            ('startTest', self),
+            ('time', endtime1),
             ('addError', self, exc_info1),
             ('stopTest', self),
+            ('time', starttime2),
             ('startTest', self),
+            ('time', endtime2),
             ('addFailure', self, exc_info2),
             ('stopTest', self),
+            ('time', starttime3),
             ('startTest', self),
+            ('time', endtime3),
             ('addSkip', self, reason),
             ('stopTest', self),
+            ('time', starttime4),
             ('startTest', self),
+            ('time', endtime4),
             ('addSuccess', self),
             ('stopTest', self),
             ], self.target._events)
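The expanded expectations in this hunk show that ThreadsafeForwardingResult forwards a test's events to its target as one batch, bracketed by startTest/stopTest, including the surrounding time() calls. A small standalone sketch of that behaviour (not part of the patch; it assumes this snapshot is importable and uses the ExtendedTestResult double)::

    import threading
    from datetime import datetime

    from testtools import TestCase, ThreadsafeForwardingResult
    from testtools.testresult.doubles import ExtendedTestResult

    class ForwardingSketch(TestCase):

        def test_events_arrive_as_one_batch(self):
            target = ExtendedTestResult()
            forwarder = ThreadsafeForwardingResult(
                target, threading.Semaphore(1))
            forwarder.time(datetime.utcfromtimestamp(1.0))
            forwarder.startTest(self)
            forwarder.time(datetime.utcfromtimestamp(2.0))
            forwarder.addSuccess(self)
            # The batch recorded on the target starts with the first time()
            # call and is closed off with stopTest.
            self.assertEqual('time', target._events[0][0])
            self.assertEqual('stopTest', target._events[-1][0])
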
@@ -536,6 +732,14 @@ class TestExtendedToOriginalResultDecoratorBase(TestCase):
         getattr(self.converter, outcome)(self, details=details)
         self.assertEqual([(outcome, self, err_str)], self.result._events)
 
+    def check_outcome_details_to_arg(self, outcome, arg, extra_detail=None):
+        """Call an outcome with a details dict to have an arg extracted."""
+        details, _ = self.get_details_and_string()
+        if extra_detail:
+            details.update(extra_detail)
+        getattr(self.converter, outcome)(self, details=details)
+        self.assertEqual([(outcome, self, arg)], self.result._events)
+
     def check_outcome_exc_info(self, outcome, expected=None):
         """Check that calling a legacy outcome still works."""
         # calling some outcome with the legacy exc_info style api (no keyword
@@ -713,8 +917,9 @@ class TestExtendedToOriginalAddError(TestExtendedToOriginalResultDecoratorBase):
 
     def test_outcome__no_details(self):
         self.make_extended_result()
-        self.assertRaises(ValueError,
-            getattr(self.converter, self.outcome), self)
+        self.assertThat(
+            lambda: getattr(self.converter, self.outcome)(self),
+            Raises(MatchesException(ValueError)))
 
 
 class TestExtendedToOriginalAddFailure(
@@ -759,18 +964,24 @@ class TestExtendedToOriginalAddSkip(
         self.make_26_result()
         self.check_outcome_string_nothing(self.outcome, 'addSuccess')
 
-    def test_outcome_Extended_py27(self):
+    def test_outcome_Extended_py27_no_reason(self):
         self.make_27_result()
         self.check_outcome_details_to_string(self.outcome)
 
+    def test_outcome_Extended_py27_reason(self):
+        self.make_27_result()
+        self.check_outcome_details_to_arg(self.outcome, 'foo',
+            {'reason': Content(UTF8_TEXT, lambda:[_b('foo')])})
+
     def test_outcome_Extended_pyextended(self):
         self.make_extended_result()
         self.check_outcome_details(self.outcome)
 
     def test_outcome__no_details(self):
         self.make_extended_result()
-        self.assertRaises(ValueError,
-            getattr(self.converter, self.outcome), self)
+        self.assertThat(
+            lambda: getattr(self.converter, self.outcome)(self),
+            Raises(MatchesException(ValueError)))
 
 
 class TestExtendedToOriginalAddSuccess(
@@ -805,9 +1016,38 @@ class TestExtendedToOriginalAddSuccess(
 
 
 class TestExtendedToOriginalAddUnexpectedSuccess(
-    TestExtendedToOriginalAddSuccess):
+    TestExtendedToOriginalResultDecoratorBase):
 
     outcome = 'addUnexpectedSuccess'
+    expected = 'addFailure'
+
+    def test_outcome_Original_py26(self):
+        self.make_26_result()
+        getattr(self.converter, self.outcome)(self)
+        [event] = self.result._events
+        self.assertEqual((self.expected, self), event[:2])
+
+    def test_outcome_Original_py27(self):
+        self.make_27_result()
+        self.check_outcome_nothing(self.outcome)
+
+    def test_outcome_Original_pyextended(self):
+        self.make_extended_result()
+        self.check_outcome_nothing(self.outcome)
+
+    def test_outcome_Extended_py26(self):
+        self.make_26_result()
+        getattr(self.converter, self.outcome)(self)
+        [event] = self.result._events
+        self.assertEqual((self.expected, self), event[:2])
+
+    def test_outcome_Extended_py27(self):
+        self.make_27_result()
+        self.check_outcome_details_to_nothing(self.outcome)
+
+    def test_outcome_Extended_pyextended(self):
+        self.make_extended_result()
+        self.check_outcome_details(self.outcome)
 
 
 class TestExtendedToOriginalResultOtherAttributes(
@@ -1030,6 +1270,11 @@ class TestNonAsciiResults(TestCase):
             'SyntaxError: '
             ), textoutput)
 
+    def test_syntax_error_malformed(self):
+        """Syntax errors with bogus parameters should break anything"""
+        textoutput = self._test_external_case("raise SyntaxError(3, 2, 1)")
+        self.assertIn(self._as_output("\nSyntaxError: "), textoutput)
+
     def test_syntax_error_import_binary(self):
         """Importing a binary file shouldn't break SyntaxError formatting"""
         if sys.version_info < (2, 5):
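The rewrite of the contract classes above turns the former inheritance chain of concrete test classes into independent mix-ins (Python26Contract through StartTestRunContract); each concrete class now combines one contract with TestCase and supplies a makeResult() factory. A sketch of how a further result implementation could reuse them; MyResult and its test class are hypothetical and not part of the patch::

    from testtools import TestCase, TestResult
    from testtools.tests.test_testresult import StartTestRunContract

    class MyResult(TestResult):
        """Stand-in for a custom result class to be checked."""

    class TestMyResultContract(TestCase, StartTestRunContract):

        # The contract mix-in provides the tests; only the factory is needed.
        def makeResult(self):
            return MyResult()
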
index 3f2f02758f4dc71aa0b53cc85ef2b8ceb617bf0d..eeb8fd28119cb973285264bff98d284ab4895274 100644 (file)
@@ -4,6 +4,7 @@
 
 __metaclass__ = type
 
+import datetime
 import unittest
 
 from testtools import (
@@ -35,16 +36,12 @@ class TestConcurrentTestSuiteRun(TestCase):
         original_suite = unittest.TestSuite([test1, test2])
         suite = ConcurrentTestSuite(original_suite, self.split_suite)
         suite.run(result)
-        test1 = log[0][1]
+        # log[0] is now the 'time' event for the first test starting.
+        test1 = log[1][1]
         test2 = log[-1][1]
         self.assertIsInstance(test1, Sample)
         self.assertIsInstance(test2, Sample)
         self.assertNotEqual(test1.id(), test2.id())
-        # We expect the start/outcome/stop to be grouped
-        expected = [('startTest', test1), ('addSuccess', test1),
-            ('stopTest', test1), ('startTest', test2), ('addSuccess', test2),
-            ('stopTest', test2)]
-        self.assertThat(log, Equals(expected))
 
     def split_suite(self, suite):
         tests = list(iterate_tests(suite))
index 8e253e63117cb5684b566d02ca71c50a615f113c..2845730f9f7c179016208b95a1d6f455fc4c2fdc 100644 (file)
@@ -20,6 +20,8 @@ from testtools import (
     )
 from testtools.matchers import (
     Equals,
+    MatchesException,
+    Raises,
     )
 from testtools.tests.helpers import (
     an_exc_info,
@@ -246,10 +248,9 @@ class TestAssertions(TestCase):
 
     def test_assertRaises_fails_when_different_error_raised(self):
         # assertRaises re-raises an exception that it didn't expect.
-        self.assertRaises(
-            ZeroDivisionError,
-            self.assertRaises,
-                RuntimeError, self.raiseError, ZeroDivisionError)
+        self.assertThat(lambda: self.assertRaises(RuntimeError,
+            self.raiseError, ZeroDivisionError),
+            Raises(MatchesException(ZeroDivisionError)))
 
     def test_assertRaises_returns_the_raised_exception(self):
         # assertRaises returns the exception object that was raised. This is
@@ -606,8 +607,8 @@ class TestAddCleanup(TestCase):
         def raiseKeyboardInterrupt():
             raise KeyboardInterrupt()
         self.test.addCleanup(raiseKeyboardInterrupt)
-        self.assertRaises(
-            KeyboardInterrupt, self.test.run, self.logging_result)
+        self.assertThat(lambda:self.test.run(self.logging_result),
+            Raises(MatchesException(KeyboardInterrupt)))
 
     def test_all_errors_from_MultipleExceptions_reported(self):
         # When a MultipleExceptions exception is caught, all the errors are
@@ -935,10 +936,12 @@ class TestSkipping(TestCase):
     """Tests for skipping of tests functionality."""
 
     def test_skip_causes_skipException(self):
-        self.assertRaises(self.skipException, self.skip, "Skip this test")
+        self.assertThat(lambda:self.skip("Skip this test"),
+            Raises(MatchesException(self.skipException)))
 
     def test_can_use_skipTest(self):
-        self.assertRaises(self.skipException, self.skipTest, "Skip this test")
+        self.assertThat(lambda:self.skipTest("Skip this test"),
+            Raises(MatchesException(self.skipException)))
 
     def test_skip_without_reason_works(self):
         class Test(TestCase):
@@ -964,8 +967,7 @@ class TestSkipping(TestCase):
         test.run(result)
         case = result._events[0][1]
         self.assertEqual([('startTest', case),
-            ('addSkip', case, "Text attachment: reason\n------------\n"
-             "skipping this test\n------------\n"), ('stopTest', case)],
+            ('addSkip', case, "skipping this test"), ('stopTest', case)],
             calls)
 
     def test_skipException_in_test_method_calls_result_addSkip(self):
@@ -977,8 +979,7 @@ class TestSkipping(TestCase):
         test.run(result)
         case = result._events[0][1]
         self.assertEqual([('startTest', case),
-            ('addSkip', case, "Text attachment: reason\n------------\n"
-             "skipping this test\n------------\n"), ('stopTest', case)],
+            ('addSkip', case, "skipping this test"), ('stopTest', case)],
             result._events)
 
     def test_skip__in_setup_with_old_result_object_calls_addSuccess(self):
@@ -1060,7 +1061,8 @@ class TestOnException(TestCase):
         class Case(TestCase):
             def method(self):
                 self.addOnException(events.index)
-                self.assertRaises(ValueError, self.onException, an_exc_info)
+                self.assertThat(lambda: self.onException(an_exc_info),
+                    Raises(MatchesException(ValueError)))
         case = Case("method")
         case.run()
         self.assertThat(events, Equals([]))
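A recurring change throughout this patch replaces assertRaises with assertThat plus the Raises and MatchesException matchers. A minimal self-contained sketch of the idiom (not part of the patch)::

    from testtools import TestCase
    from testtools.matchers import MatchesException, Raises

    class RaisesIdiomExample(TestCase):

        def test_division_by_zero(self):
            # The Raises matcher invokes the callable and captures exc_info;
            # MatchesException then checks the exception's type.
            self.assertThat(
                lambda: 1 / 0,
                Raises(MatchesException(ZeroDivisionError)))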