lib: Protect against tevent nterror mismatches
[nivanova/samba-autobuild/.git] / lib / testtools / testtools / tests / test_spinner.py
1 # Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details.
2
3 """Tests for the evil Twisted reactor-spinning we do."""
4
5 import os
6 import signal
7
8 from testtools import (
9     skipIf,
10     TestCase,
11     )
12 from testtools.helpers import try_import
13 from testtools.matchers import (
14     Equals,
15     Is,
16     MatchesException,
17     Raises,
18     )
19
20 _spinner = try_import('testtools._spinner')
21
22 defer = try_import('twisted.internet.defer')
23 Failure = try_import('twisted.python.failure.Failure')
24
25
class NeedsTwistedTestCase(TestCase):
    """Base class for tests that can only run when Twisted is importable.

    ``setUp`` skips the test if either ``twisted.internet.defer`` or
    ``twisted.python.failure.Failure`` failed to import.
    """

    def setUp(self):
        super(NeedsTwistedTestCase, self).setUp()
        # try_import leaves these bound to None when Twisted is missing.
        twisted_missing = any(dep is None for dep in (defer, Failure))
        if twisted_missing:
            self.skipTest("Need Twisted to run")
32
33
class TestNotReentrant(NeedsTwistedTestCase):
    """Tests for the ``_spinner.not_reentrant`` decorator."""

    def test_not_reentrant(self):
        # Recursing into a function guarded by not_reentrant raises
        # _spinner.ReentryError, so the body only gets to run once.
        invocations = []

        @_spinner.not_reentrant
        def log_something():
            invocations.append(None)
            if len(invocations) < 5:
                log_something()

        reentry_error = Raises(MatchesException(_spinner.ReentryError))
        self.assertThat(log_something, reentry_error)
        self.assertEqual(1, len(invocations))

    def test_deeper_stack(self):
        # Entering g while f is running is allowed, but g then calls back
        # into f, which is re-entry: f and g each run exactly once.
        invocations = []

        @_spinner.not_reentrant
        def g():
            invocations.append(None)
            if len(invocations) < 5:
                f()

        @_spinner.not_reentrant
        def f():
            invocations.append(None)
            if len(invocations) < 5:
                g()

        reentry_error = Raises(MatchesException(_spinner.ReentryError))
        self.assertThat(f, reentry_error)
        self.assertEqual(2, len(invocations))
63
64
class TestExtractResult(NeedsTwistedTestCase):
    """Tests for ``_spinner.extract_result``."""

    def test_not_fired(self):
        # A Deferred that has not fired has no result yet, so
        # extract_result raises _spinner.DeferredNotFired.
        def extract_from_unfired():
            return _spinner.extract_result(defer.Deferred())

        self.assertThat(
            extract_from_unfired,
            Raises(MatchesException(_spinner.DeferredNotFired)))

    def test_success(self):
        # A successfully fired Deferred yields its value directly.
        sentinel = object()
        fired = defer.succeed(sentinel)
        extracted = _spinner.extract_result(fired)
        self.assertThat(extracted, Equals(sentinel))

    def test_failure(self):
        # A failing Deferred makes extract_result re-raise the wrapped
        # exception. Failure() must be built inside the except clause so it
        # captures the active ZeroDivisionError.
        try:
            1/0
        except ZeroDivisionError:
            failure = Failure()
        failing = defer.fail(failure)

        def extract_from_failing():
            return _spinner.extract_result(failing)

        self.assertThat(
            extract_from_failing,
            Raises(MatchesException(ZeroDivisionError)))
90
91
class TestTrapUnhandledErrors(NeedsTwistedTestCase):
    """Tests for ``_spinner.trap_unhandled_errors``."""

    def test_no_deferreds(self):
        # With no Deferreds involved, the callable's return value comes back
        # untouched and the list of trapped errors is empty.
        sentinel = object()
        returned, trapped = _spinner.trap_unhandled_errors(lambda: sentinel)
        self.assertEqual([], trapped)
        self.assertIs(sentinel, returned)

    def test_unhandled_error(self):
        # A failing Deferred that nobody handles is reported back; each
        # trapped error exposes the original Failure via ``failResult``.
        created = []

        def make_deferred_but_dont_handle():
            try:
                1/0
            except ZeroDivisionError:
                failure = Failure()
                created.append(failure)
                defer.fail(failure)

        returned, trapped = _spinner.trap_unhandled_errors(
            make_deferred_but_dont_handle)
        self.assertIs(None, returned)
        reported = [err.failResult for err in trapped]
        self.assertEqual(created, reported)
113
114
class TestRunInReactor(NeedsTwistedTestCase):
    """Tests for ``_spinner.Spinner``, which runs callables in the reactor.

    Each test wraps the global Twisted reactor in a Spinner. The timeout is
    very short so that tests waiting on a never-firing Deferred end quickly.
    """

    def make_reactor(self):
        # Imported here rather than at module level, presumably so this
        # module still imports when Twisted is missing (setUp then skips).
        from twisted.internet import reactor
        return reactor

    def make_spinner(self, reactor=None):
        # Wrap ``reactor`` (default: the global reactor) in a Spinner.
        if reactor is None:
            reactor = self.make_reactor()
        return _spinner.Spinner(reactor)

    def make_timeout(self):
        # Timeout in seconds for Spinner.run (it is also fed to callLater and
        # time.sleep below). Kept tiny so timeouts trigger fast.
        return 0.01

    def test_function_called(self):
        # run_in_reactor actually calls the function given to it.
        calls = []
        marker = object()
        self.make_spinner().run(self.make_timeout(), calls.append, marker)
        self.assertThat(calls, Equals([marker]))

    def test_return_value_returned(self):
        # run_in_reactor returns the value returned by the function given to
        # it.
        marker = object()
        result = self.make_spinner().run(self.make_timeout(), lambda: marker)
        self.assertThat(result, Is(marker))

    def test_exception_reraised(self):
        # If the given function raises an error, run_in_reactor re-raises that
        # error.
        self.assertThat(
            lambda: self.make_spinner().run(self.make_timeout(), lambda: 1/0),
            Raises(MatchesException(ZeroDivisionError)))

    def test_keyword_arguments(self):
        # run_in_reactor passes keyword arguments on.
        calls = []

        def function(*a, **kw):
            calls.extend([a, kw])

        self.make_spinner().run(self.make_timeout(), function, foo=42)
        self.assertThat(calls, Equals([(), {'foo': 42}]))

    def test_not_reentrant(self):
        # run_in_reactor raises an error if it is called inside another call
        # to run_in_reactor.
        spinner = self.make_spinner()
        self.assertThat(lambda: spinner.run(
            self.make_timeout(), spinner.run, self.make_timeout(),
            lambda: None), Raises(MatchesException(_spinner.ReentryError)))

    def test_deferred_value_returned(self):
        # If the given function returns a Deferred, run_in_reactor returns the
        # value in the Deferred at the end of the callback chain.
        marker = object()
        result = self.make_spinner().run(
            self.make_timeout(), lambda: defer.succeed(marker))
        self.assertThat(result, Is(marker))

    def test_preserve_signal_handler(self):
        # After Spinner.run returns, the signal handlers installed before the
        # run are still in place.
        #
        # ``signals`` must be a real list. On Python 3, ``filter`` and ``map``
        # return one-shot iterators. The first loop would exhaust ``signals``,
        # leaving ``new_hdlrs`` empty, and the final comparison would be
        # against a map object, which never equals a list.
        names = ['SIGINT', 'SIGTERM', 'SIGCHLD']
        signals = [
            sig for sig in (getattr(signal, name, None) for name in names)
            if sig]
        for sig in signals:
            self.addCleanup(signal.signal, sig, signal.getsignal(sig))
        # One distinct no-op handler per signal, so identity can be checked.
        new_hdlrs = [lambda *a: None for _ in signals]
        for sig, hdlr in zip(signals, new_hdlrs):
            signal.signal(sig, hdlr)
        spinner = self.make_spinner()
        spinner.run(self.make_timeout(), lambda: None)
        self.assertEqual(
            new_hdlrs, [signal.getsignal(sig) for sig in signals])

    def test_timeout(self):
        # If the function takes too long to run, we raise a
        # _spinner.TimeoutError.
        timeout = self.make_timeout()
        self.assertThat(
            lambda: self.make_spinner().run(timeout, lambda: defer.Deferred()),
            Raises(MatchesException(_spinner.TimeoutError)))

    def test_no_junk_by_default(self):
        # If the reactor hasn't spun yet, then there cannot be any junk.
        spinner = self.make_spinner()
        self.assertThat(spinner.get_junk(), Equals([]))

    def test_clean_do_nothing(self):
        # If there's nothing going on in the reactor, then clean does nothing
        # and returns an empty list.
        spinner = self.make_spinner()
        result = spinner._clean()
        self.assertThat(result, Equals([]))

    def test_clean_delayed_call(self):
        # If there's a delayed call in the reactor, then clean cancels it and
        # returns a list containing that call.
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        call = reactor.callLater(10, lambda: None)
        results = spinner._clean()
        self.assertThat(results, Equals([call]))
        self.assertThat(call.active(), Equals(False))

    def test_clean_delayed_call_cancelled(self):
        # If there's a delayed call that's just been cancelled, then it's no
        # longer there.
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        call = reactor.callLater(10, lambda: None)
        call.cancel()
        results = spinner._clean()
        self.assertThat(results, Equals([]))

    def test_clean_selectables(self):
        # If there's still a selectable (e.g. a listening socket), then
        # clean() removes it from the reactor's registry.
        #
        # Note that the socket is left open. This emulates a bug in trial.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        port = reactor.listenTCP(0, ServerFactory())
        spinner.run(self.make_timeout(), lambda: None)
        results = spinner.get_junk()
        self.assertThat(results, Equals([port]))

    def test_clean_running_threads(self):
        # Once run returns, no threads started through reactor.callInThread
        # during the run are left alive: the thread list matches the one
        # taken beforehand.
        import threading
        import time
        current_threads = list(threading.enumerate())
        reactor = self.make_reactor()
        timeout = self.make_timeout()
        spinner = self.make_spinner(reactor)
        spinner.run(timeout, reactor.callInThread, time.sleep, timeout / 2.0)
        self.assertThat(list(threading.enumerate()), Equals(current_threads))

    def test_leftover_junk_available(self):
        # If 'run' is given a function that leaves the reactor dirty in some
        # way, 'run' will clean up the reactor and then store information
        # about the junk. This information can be got using get_junk.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        port = spinner.run(
            self.make_timeout(), reactor.listenTCP, 0, ServerFactory())
        self.assertThat(spinner.get_junk(), Equals([port]))

    def test_will_not_run_with_previous_junk(self):
        # If 'run' is called and there's still junk in the spinner's junk
        # list, then the spinner will refuse to run.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
        self.assertThat(lambda: spinner.run(timeout, lambda: None),
            Raises(MatchesException(_spinner.StaleJunkError)))

    def test_clear_junk_clears_previous_junk(self):
        # clear_junk returns the junk left over from a previous run and
        # empties the spinner's junk list, so get_junk is empty afterwards.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        port = spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
        junk = spinner.clear_junk()
        self.assertThat(junk, Equals([port]))
        self.assertThat(spinner.get_junk(), Equals([]))

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_sigint_raises_no_result_error(self):
        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
        SIGINT = getattr(signal, 'SIGINT', None)
        if not SIGINT:
            self.skipTest("SIGINT not available")
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        reactor.callLater(timeout, os.kill, os.getpid(), SIGINT)
        self.assertThat(lambda: spinner.run(timeout * 5, defer.Deferred),
            Raises(MatchesException(_spinner.NoResultError)))
        self.assertEqual([], spinner._clean())

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_sigint_raises_no_result_error_second_time(self):
        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
        # This test is exactly the same as test_sigint_raises_no_result_error,
        # and exists to make sure we haven't futzed with state.
        self.test_sigint_raises_no_result_error()

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_fast_sigint_raises_no_result_error(self):
        # If we get a SIGINT as soon as the reactor starts running, we raise
        # _spinner.NoResultError.
        SIGINT = getattr(signal, 'SIGINT', None)
        if not SIGINT:
            self.skipTest("SIGINT not available")
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        reactor.callWhenRunning(os.kill, os.getpid(), SIGINT)
        self.assertThat(lambda: spinner.run(timeout * 5, defer.Deferred),
            Raises(MatchesException(_spinner.NoResultError)))
        self.assertEqual([], spinner._clean())

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_fast_sigint_raises_no_result_error_second_time(self):
        # Repeat of test_fast_sigint_raises_no_result_error, to check that
        # the first run did not leave bad state behind.
        self.test_fast_sigint_raises_no_result_error()
321
322
def test_suite():
    """Return a unittest suite built from every test in this module."""
    import unittest
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)