testtools: Merge in new upstream.
[nivanova/samba-autobuild/.git] / lib / testtools / testtools / tests / test_spinner.py
1 # Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
2
3 """Tests for the evil Twisted reactor-spinning we do."""
4
5 import os
6 import signal
7
8 from testtools import (
9     skipIf,
10     TestCase,
11     )
12 from testtools.helpers import try_import
13 from testtools.matchers import (
14     Equals,
15     Is,
16     MatchesException,
17     Raises,
18     )
19
20 _spinner = try_import('testtools._spinner')
21
22 defer = try_import('twisted.internet.defer')
23 Failure = try_import('twisted.python.failure.Failure')
24
25
class NeedsTwistedTestCase(TestCase):
    """Base for test cases that skip themselves when Twisted is missing."""

    def setUp(self):
        super(NeedsTwistedTestCase, self).setUp()
        # The module-level ``defer`` / ``Failure`` names are None when Twisted
        # could not be imported; in that case there is nothing to test.
        if any(dependency is None for dependency in (defer, Failure)):
            self.skipTest("Need Twisted to run")
32
33
class TestNotReentrant(NeedsTwistedTestCase):
    """Tests for the ``_spinner.not_reentrant`` decorator."""

    def test_not_reentrant(self):
        # Calling a not_reentrant function from inside itself raises
        # _spinner.ReentryError, so its body only gets to run once.
        invocations = []

        @_spinner.not_reentrant
        def recurse():
            invocations.append(None)
            if len(invocations) < 5:
                recurse()

        self.assertThat(
            recurse, Raises(MatchesException(_spinner.ReentryError)))
        self.assertEqual(1, len(invocations))

    def test_deeper_stack(self):
        # Re-entry is also detected when it happens indirectly through a
        # second not_reentrant function: first -> second -> first is refused
        # after both bodies have run once.
        invocations = []

        @_spinner.not_reentrant
        def second():
            invocations.append(None)
            if len(invocations) < 5:
                first()

        @_spinner.not_reentrant
        def first():
            invocations.append(None)
            if len(invocations) < 5:
                second()

        self.assertThat(
            first, Raises(MatchesException(_spinner.ReentryError)))
        self.assertEqual(2, len(invocations))
63
64
class TestExtractResult(NeedsTwistedTestCase):
    """Tests for ``_spinner.extract_result``."""

    def test_not_fired(self):
        # A Deferred that has not fired yet has no result to extract, so
        # _spinner.DeferredNotFired is raised.
        unfired = defer.Deferred()
        self.assertThat(
            lambda: _spinner.extract_result(unfired),
            Raises(MatchesException(_spinner.DeferredNotFired)))

    def test_success(self):
        # A Deferred that fired successfully yields the value it fired with.
        sentinel = object()
        self.assertThat(
            _spinner.extract_result(defer.succeed(sentinel)),
            Equals(sentinel))

    def test_failure(self):
        # A Deferred that fired with a Failure has that Failure's exception
        # raised out of extract_result.
        try:
            1/0
        except ZeroDivisionError:
            failure = Failure()
        failed = defer.fail(failure)
        self.assertThat(
            lambda: _spinner.extract_result(failed),
            Raises(MatchesException(ZeroDivisionError)))
90
91
class TestTrapUnhandledErrors(NeedsTwistedTestCase):
    """Tests for ``_spinner.trap_unhandled_errors``."""

    def test_no_deferreds(self):
        # With no Deferreds involved, the function's return value is passed
        # straight through and no errors are reported.
        sentinel = object()
        result, errors = _spinner.trap_unhandled_errors(lambda: sentinel)
        self.assertEqual([], errors)
        self.assertIs(sentinel, result)

    def test_unhandled_error(self):
        # A failed Deferred that nobody handles is reported back; each
        # reported error exposes the original Failure as ``failResult``.
        created = []

        def fail_without_errback():
            try:
                1/0
            except ZeroDivisionError:
                failure = Failure()
                created.append(failure)
                defer.fail(failure)

        result, errors = _spinner.trap_unhandled_errors(fail_without_errback)
        self.assertIs(None, result)
        trapped = [error.failResult for error in errors]
        self.assertEqual(created, trapped)
113
114
class TestRunInReactor(NeedsTwistedTestCase):
    """Tests for running functions in a reactor with ``_spinner.Spinner``."""

    def make_reactor(self):
        # The reactor installed as ``twisted.internet.reactor``.
        from twisted.internet import reactor
        return reactor

    def make_spinner(self, reactor=None):
        # Wrap ``reactor`` (default: make_reactor()) in a Spinner.
        if reactor is None:
            reactor = self.make_reactor()
        return _spinner.Spinner(reactor)

    def make_timeout(self):
        # Timeout, in seconds, handed to Spinner.run by these tests.
        return 0.01

    def test_function_called(self):
        # Spinner.run actually calls the function given to it.
        calls = []
        marker = object()
        self.make_spinner().run(self.make_timeout(), calls.append, marker)
        self.assertThat(calls, Equals([marker]))

    def test_return_value_returned(self):
        # Spinner.run returns the value returned by the function given to it.
        marker = object()
        result = self.make_spinner().run(self.make_timeout(), lambda: marker)
        self.assertThat(result, Is(marker))

    def test_exception_reraised(self):
        # If the given function raises an error, Spinner.run re-raises that
        # error.
        self.assertThat(
            lambda: self.make_spinner().run(self.make_timeout(), lambda: 1/0),
            Raises(MatchesException(ZeroDivisionError)))

    def test_keyword_arguments(self):
        # Spinner.run passes keyword arguments on to the function.
        calls = []

        def record(*a, **kw):
            calls.extend([a, kw])

        self.make_spinner().run(self.make_timeout(), record, foo=42)
        self.assertThat(calls, Equals([(), {'foo': 42}]))

    def test_not_reentrant(self):
        # Spinner.run raises an error if it is called inside another call to
        # Spinner.run.
        spinner = self.make_spinner()
        self.assertThat(lambda: spinner.run(
            self.make_timeout(), spinner.run, self.make_timeout(),
            lambda: None), Raises(MatchesException(_spinner.ReentryError)))

    def test_deferred_value_returned(self):
        # If the given function returns a Deferred, Spinner.run returns the
        # value in the Deferred at the end of the callback chain.
        marker = object()
        result = self.make_spinner().run(
            self.make_timeout(), lambda: defer.succeed(marker))
        self.assertThat(result, Is(marker))

    def test_preserve_signal_handler(self):
        # The signal handlers in place before Spinner.run are still in place
        # after it returns.
        names = ['SIGINT', 'SIGTERM', 'SIGCHLD']
        # Build a real list: ``signals`` is iterated several times below, and
        # on Python 3 filter()/map() return one-shot iterators, which made
        # every loop after the first see no signals at all.
        signals = [
            sig for sig in (getattr(signal, name, None) for name in names)
            if sig is not None]
        for sig in signals:
            self.addCleanup(signal.signal, sig, signal.getsignal(sig))
        new_hdlrs = [lambda *a: None for _ in signals]
        for sig, hdlr in zip(signals, new_hdlrs):
            signal.signal(sig, hdlr)
        spinner = self.make_spinner()
        spinner.run(self.make_timeout(), lambda: None)
        self.assertEqual(
            new_hdlrs, [signal.getsignal(sig) for sig in signals])

    def test_timeout(self):
        # If the function takes too long to run, we raise a
        # _spinner.TimeoutError.
        timeout = self.make_timeout()
        self.assertThat(
            lambda: self.make_spinner().run(timeout, lambda: defer.Deferred()),
            Raises(MatchesException(_spinner.TimeoutError)))

    def test_no_junk_by_default(self):
        # If the reactor hasn't spun yet, then there cannot be any junk.
        spinner = self.make_spinner()
        self.assertThat(spinner.get_junk(), Equals([]))

    def test_clean_do_nothing(self):
        # If there's nothing going on in the reactor, then clean does nothing
        # and returns an empty list.
        spinner = self.make_spinner()
        result = spinner._clean()
        self.assertThat(result, Equals([]))

    def test_clean_delayed_call(self):
        # If there's a delayed call in the reactor, then clean cancels it and
        # returns it in the list of junk.
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        call = reactor.callLater(10, lambda: None)
        results = spinner._clean()
        self.assertThat(results, Equals([call]))
        self.assertThat(call.active(), Equals(False))

    def test_clean_delayed_call_cancelled(self):
        # If there's a delayed call that's just been cancelled, then it's no
        # longer there.
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        call = reactor.callLater(10, lambda: None)
        call.cancel()
        results = spinner._clean()
        self.assertThat(results, Equals([]))

    def test_clean_selectables(self):
        # If there's still a selectable (e.g. a listening socket), then
        # clean() removes it from the reactor's registry.
        #
        # Note that the socket is left open. This emulates a bug in trial.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        port = reactor.listenTCP(0, ServerFactory())
        spinner.run(self.make_timeout(), lambda: None)
        results = spinner.get_junk()
        self.assertThat(results, Equals([port]))

    def test_clean_running_threads(self):
        # After Spinner.run returns, the only live threads are the ones that
        # existed before it: a thread started via reactor.callInThread does
        # not outlive the run.
        import threading
        import time

        def is_alive(thread):
            # Thread.isAlive() was removed in Python 3.9; is_alive() exists
            # from Python 2.6 on. Fall back only for older interpreters.
            check = getattr(thread, 'is_alive', None)
            if check is None:
                check = thread.isAlive
            return check()

        current_threads = list(threading.enumerate())
        reactor = self.make_reactor()
        timeout = self.make_timeout()
        spinner = self.make_spinner(reactor)
        spinner.run(timeout, reactor.callInThread, time.sleep, timeout / 2.0)
        # Python before 2.5 has a race condition with thread handling where
        # join() does not remove threads from enumerate before returning - the
        # thread being joined does the removal. This was fixed in Python 2.5
        # but we still support 2.4, so we have to workaround the issue
        # (by only counting threads that are still alive).
        # http://bugs.python.org/issue1703448.
        self.assertThat(
            [thread for thread in threading.enumerate() if is_alive(thread)],
            Equals(current_threads))

    def test_leftover_junk_available(self):
        # If 'run' is given a function that leaves the reactor dirty in some
        # way, 'run' will clean up the reactor and then store information
        # about the junk. This information can be got using get_junk.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        port = spinner.run(
            self.make_timeout(), reactor.listenTCP, 0, ServerFactory())
        self.assertThat(spinner.get_junk(), Equals([port]))

    def test_will_not_run_with_previous_junk(self):
        # If 'run' is called and there's still junk in the spinner's junk
        # list, then the spinner will refuse to run.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
        self.assertThat(lambda: spinner.run(timeout, lambda: None),
            Raises(MatchesException(_spinner.StaleJunkError)))

    def test_clear_junk_clears_previous_junk(self):
        # clear_junk() returns the junk left behind by the previous run and
        # empties the spinner's junk list.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        port = spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
        junk = spinner.clear_junk()
        self.assertThat(junk, Equals([port]))
        self.assertThat(spinner.get_junk(), Equals([]))

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_sigint_raises_no_result_error(self):
        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
        SIGINT = getattr(signal, 'SIGINT', None)
        if not SIGINT:
            self.skipTest("SIGINT not available")
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        reactor.callLater(timeout, os.kill, os.getpid(), SIGINT)
        self.assertThat(lambda: spinner.run(timeout * 5, defer.Deferred),
            Raises(MatchesException(_spinner.NoResultError)))
        self.assertEqual([], spinner._clean())

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_sigint_raises_no_result_error_second_time(self):
        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
        # This test is exactly the same as test_sigint_raises_no_result_error,
        # and exists to make sure we haven't futzed with state.
        self.test_sigint_raises_no_result_error()

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_fast_sigint_raises_no_result_error(self):
        # If we get a SIGINT as soon as the reactor starts running (rather
        # than after a delay), we still raise _spinner.NoResultError.
        SIGINT = getattr(signal, 'SIGINT', None)
        if not SIGINT:
            self.skipTest("SIGINT not available")
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        reactor.callWhenRunning(os.kill, os.getpid(), SIGINT)
        self.assertThat(lambda: spinner.run(timeout * 5, defer.Deferred),
            Raises(MatchesException(_spinner.NoResultError)))
        self.assertEqual([], spinner._clean())

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_fast_sigint_raises_no_result_error_second_time(self):
        # Exactly the same as test_fast_sigint_raises_no_result_error; run
        # again to make sure the first SIGINT didn't leave stale state behind.
        self.test_fast_sigint_raises_no_result_error()
328
329
def test_suite():
    """Build a unittest suite holding every test case in this module."""
    import unittest
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)