testtools: Update to latest upstream version.
[kai/samba.git] / lib / testtools / testtools / _spinner.py
1 # Copyright (c) 2010 testtools developers. See LICENSE for details.
2
3 """Evil reactor-spinning logic for running Twisted tests.
4
5 This code is highly experimental, liable to change and not to be trusted.  If
6 you couldn't write this yourself, you should not be using it.
7 """
8
9 __all__ = [
10     'DeferredNotFired',
11     'extract_result',
12     'NoResultError',
13     'not_reentrant',
14     'ReentryError',
15     'Spinner',
16     'StaleJunkError',
17     'TimeoutError',
18     'trap_unhandled_errors',
19     ]
20
21 import signal
22
23 from testtools.monkey import MonkeyPatcher
24
25 from twisted.internet import defer
26 from twisted.internet.base import DelayedCall
27 from twisted.internet.interfaces import IReactorThreads
28 from twisted.python.failure import Failure
29 from twisted.python.util import mergeFunctionMetadata
30
31
32 class ReentryError(Exception):
33     """Raised when we try to re-enter a function that forbids it."""
34
35     def __init__(self, function):
36         Exception.__init__(self,
37             "%r in not re-entrant but was called within a call to itself."
38             % (function,))
39
40
41 def not_reentrant(function, _calls={}):
42     """Decorates a function as not being re-entrant.
43
44     The decorated function will raise an error if called from within itself.
45     """
46     def decorated(*args, **kwargs):
47         if _calls.get(function, False):
48             raise ReentryError(function)
49         _calls[function] = True
50         try:
51             return function(*args, **kwargs)
52         finally:
53             _calls[function] = False
54     return mergeFunctionMetadata(function, decorated)
55
56
57 class DeferredNotFired(Exception):
58     """Raised when we extract a result from a Deferred that's not fired yet."""
59
60
61 def extract_result(deferred):
62     """Extract the result from a fired deferred.
63
64     It can happen that you have an API that returns Deferreds for
65     compatibility with Twisted code, but is in fact synchronous, i.e. the
66     Deferreds it returns have always fired by the time it returns.  In this
67     case, you can use this function to convert the result back into the usual
68     form for a synchronous API, i.e. the result itself or a raised exception.
69
70     It would be very bad form to use this as some way of checking if a
71     Deferred has fired.
72     """
73     failures = []
74     successes = []
75     deferred.addCallbacks(successes.append, failures.append)
76     if len(failures) == 1:
77         failures[0].raiseException()
78     elif len(successes) == 1:
79         return successes[0]
80     else:
81         raise DeferredNotFired("%r has not fired yet." % (deferred,))
82
83
84 def trap_unhandled_errors(function, *args, **kwargs):
85     """Run a function, trapping any unhandled errors in Deferreds.
86
87     Assumes that 'function' will have handled any errors in Deferreds by the
88     time it is complete.  This is almost never true of any Twisted code, since
89     you can never tell when someone has added an errback to a Deferred.
90
91     If 'function' raises, then don't bother doing any unhandled error
92     jiggery-pokery, since something horrible has probably happened anyway.
93
94     :return: A tuple of '(result, error)', where 'result' is the value
95         returned by 'function' and 'error' is a list of 'defer.DebugInfo'
96         objects that have unhandled errors in Deferreds.
97     """
98     real_DebugInfo = defer.DebugInfo
99     debug_infos = []
100     def DebugInfo():
101         info = real_DebugInfo()
102         debug_infos.append(info)
103         return info
104     defer.DebugInfo = DebugInfo
105     try:
106         result = function(*args, **kwargs)
107     finally:
108         defer.DebugInfo = real_DebugInfo
109     errors = []
110     for info in debug_infos:
111         if info.failResult is not None:
112             errors.append(info)
113             # Disable the destructor that logs to error. We are already
114             # catching the error here.
115             info.__del__ = lambda: None
116     return result, errors
117
118
119 class TimeoutError(Exception):
120     """Raised when run_in_reactor takes too long to run a function."""
121
122     def __init__(self, function, timeout):
123         Exception.__init__(self,
124             "%r took longer than %s seconds" % (function, timeout))
125
126
127 class NoResultError(Exception):
128     """Raised when the reactor has stopped but we don't have any result."""
129
130     def __init__(self):
131         Exception.__init__(self,
132             "Tried to get test's result from Deferred when no result is "
133             "available.  Probably means we received SIGINT or similar.")
134
135
136 class StaleJunkError(Exception):
137     """Raised when there's junk in the spinner from a previous run."""
138
139     def __init__(self, junk):
140         Exception.__init__(self,
141             "There was junk in the spinner from a previous run. "
142             "Use clear_junk() to clear it out: %r" % (junk,))
143
144
145 class Spinner(object):
146     """Spin the reactor until a function is done.
147
148     This class emulates the behaviour of twisted.trial in that it grotesquely
149     and horribly spins the Twisted reactor while a function is running, and
150     then kills the reactor when that function is complete and all the
151     callbacks in its chains are done.
152     """
153
154     _UNSET = object()
155
156     # Signals that we save and restore for each spin.
157     _PRESERVED_SIGNALS = [
158         'SIGINT',
159         'SIGTERM',
160         'SIGCHLD',
161         ]
162
163     # There are many APIs within Twisted itself where a Deferred fires but
164     # leaves cleanup work scheduled for the reactor to do.  Arguably, many of
165     # these are bugs.  As such, we provide a facility to iterate the reactor
166     # event loop a number of times after every call, in order to shake out
167     # these buggy-but-commonplace events.  The default is 0, because that is
168     # the ideal, and it actually works for many cases.
169     _OBLIGATORY_REACTOR_ITERATIONS = 0
170
171     def __init__(self, reactor, debug=False):
172         """Construct a Spinner.
173
174         :param reactor: A Twisted reactor.
175         :param debug: Whether or not to enable Twisted's debugging.  Defaults
176             to False.
177         """
178         self._reactor = reactor
179         self._timeout_call = None
180         self._success = self._UNSET
181         self._failure = self._UNSET
182         self._saved_signals = []
183         self._junk = []
184         self._debug = debug
185
186     def _cancel_timeout(self):
187         if self._timeout_call:
188             self._timeout_call.cancel()
189
190     def _get_result(self):
191         if self._failure is not self._UNSET:
192             self._failure.raiseException()
193         if self._success is not self._UNSET:
194             return self._success
195         raise NoResultError()
196
197     def _got_failure(self, result):
198         self._cancel_timeout()
199         self._failure = result
200
201     def _got_success(self, result):
202         self._cancel_timeout()
203         self._success = result
204
205     def _stop_reactor(self, ignored=None):
206         """Stop the reactor!"""
207         self._reactor.crash()
208
209     def _timed_out(self, function, timeout):
210         e = TimeoutError(function, timeout)
211         self._failure = Failure(e)
212         self._stop_reactor()
213
214     def _clean(self):
215         """Clean up any junk in the reactor.
216
217         Will always iterate the reactor a number of times equal to
218         ``Spinner._OBLIGATORY_REACTOR_ITERATIONS``.  This is to work around
219         bugs in various Twisted APIs where a Deferred fires but still leaves
220         work (e.g. cancelling a call, actually closing a connection) for the
221         reactor to do.
222         """
223         for i in range(self._OBLIGATORY_REACTOR_ITERATIONS):
224             self._reactor.iterate(0)
225         junk = []
226         for delayed_call in self._reactor.getDelayedCalls():
227             delayed_call.cancel()
228             junk.append(delayed_call)
229         for selectable in self._reactor.removeAll():
230             # Twisted sends a 'KILL' signal to selectables that provide
231             # IProcessTransport.  Since only _dumbwin32proc processes do this,
232             # we aren't going to bother.
233             junk.append(selectable)
234         if IReactorThreads.providedBy(self._reactor):
235             if self._reactor.threadpool is not None:
236                 self._reactor._stopThreadPool()
237         self._junk.extend(junk)
238         return junk
239
240     def clear_junk(self):
241         """Clear out our recorded junk.
242
243         :return: Whatever junk was there before.
244         """
245         junk = self._junk
246         self._junk = []
247         return junk
248
249     def get_junk(self):
250         """Return any junk that has been found on the reactor."""
251         return self._junk
252
253     def _save_signals(self):
254         available_signals = [
255             getattr(signal, name, None) for name in self._PRESERVED_SIGNALS]
256         self._saved_signals = [
257             (sig, signal.getsignal(sig)) for sig in available_signals if sig]
258
259     def _restore_signals(self):
260         for sig, hdlr in self._saved_signals:
261             signal.signal(sig, hdlr)
262         self._saved_signals = []
263
264     @not_reentrant
265     def run(self, timeout, function, *args, **kwargs):
266         """Run 'function' in a reactor.
267
268         If 'function' returns a Deferred, the reactor will keep spinning until
269         the Deferred fires and its chain completes or until the timeout is
270         reached -- whichever comes first.
271
272         :raise TimeoutError: If 'timeout' is reached before the Deferred
273             returned by 'function' has completed its callback chain.
274         :raise NoResultError: If the reactor is somehow interrupted before
275             the Deferred returned by 'function' has completed its callback
276             chain.
277         :raise StaleJunkError: If there's junk in the spinner from a previous
278             run.
279         :return: Whatever is at the end of the function's callback chain.  If
280             it's an error, then raise that.
281         """
282         debug = MonkeyPatcher()
283         if self._debug:
284             debug.add_patch(defer.Deferred, 'debug', True)
285             debug.add_patch(DelayedCall, 'debug', True)
286         debug.patch()
287         try:
288             junk = self.get_junk()
289             if junk:
290                 raise StaleJunkError(junk)
291             self._save_signals()
292             self._timeout_call = self._reactor.callLater(
293                 timeout, self._timed_out, function, timeout)
294             # Calling 'stop' on the reactor will make it impossible to
295             # re-start the reactor.  Since the default signal handlers for
296             # TERM, BREAK and INT all call reactor.stop(), we'll patch it over
297             # with crash.  XXX: It might be a better idea to either install
298             # custom signal handlers or to override the methods that are
299             # Twisted's signal handlers.
300             stop, self._reactor.stop = self._reactor.stop, self._reactor.crash
301             def run_function():
302                 d = defer.maybeDeferred(function, *args, **kwargs)
303                 d.addCallbacks(self._got_success, self._got_failure)
304                 d.addBoth(self._stop_reactor)
305             try:
306                 self._reactor.callWhenRunning(run_function)
307                 self._reactor.run()
308             finally:
309                 self._reactor.stop = stop
310                 self._restore_signals()
311             try:
312                 return self._get_result()
313             finally:
314                 self._clean()
315         finally:
316             debug.restore()