3abfbf522e4832ea09fe0b8674b6c88ed523ebb0
[jra/samba/.git] / source4 / scripting / python / subunit / __init__.py
1 #
2 #  subunit: extensions to python unittest to get test results from subprocesses.
3 #  Copyright (C) 2005  Robert Collins <robertc@robertcollins.net>
4 #  Copyright (C) 2007  Jelmer Vernooij <jelmer@samba.org>
5 #
6 #  This program is free software; you can redistribute it and/or modify
7 #  it under the terms of the GNU General Public License as published by
8 #  the Free Software Foundation; either version 2 of the License, or
9 #  (at your option) any later version.
10 #
11 #  This program is distributed in the hope that it will be useful,
12 #  but WITHOUT ANY WARRANTY; without even the implied warranty of
13 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 #  GNU General Public License for more details.
15 #
16 #  You should have received a copy of the GNU General Public License
17 #  along with this program; if not, write to the Free Software
18 #  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
19 #
20
21 import os
22 from StringIO import StringIO
23 import sys
24 import unittest
25
def test_suite():
    """Return the test suite of the subunit package.

    C{subunit.tests} is imported inside the function body, i.e. only when
    the suite is actually requested.

    @return: The value returned by C{subunit.tests.test_suite()}.
    """
    import subunit.tests
    build_suite = subunit.tests.test_suite
    return build_suite()
29
30
def join_dir(base_path, path):
    """
    Resolve C{path} against the directory that contains C{base_path}.

    @param base_path: A path to a file or directory.  It is made absolute
        and its last component is dropped to obtain the anchor directory.
    @param path: An absolute path, or a path relative to the anchor
        directory.

    @return: The anchor directory joined with C{path}.  When C{path} is
        already absolute, C{os.path.join} discards the anchor and C{path}
        is returned as is.
    """
    anchor_dir = os.path.dirname(os.path.abspath(base_path))
    return os.path.join(anchor_dir, path)
43
44
class TestProtocolServer(object):
    """A class for receiving results from a TestProtocol client.

    Lines are fed in through lineReceived() (or readFrom()).  A small state
    machine tracks whether a test is running and whether a bracketed
    failure/error message is being read.  Recognised lines are turned into
    calls on the client; every other line is forwarded to the stream.
    """

    OUTSIDE_TEST = 0
    TEST_STARTED = 1
    READING_FAILURE = 2
    READING_ERROR = 3

    # Protocol keyword (surrounding colons removed) -> name of the handler
    # method.  Handlers are looked up by name on the instance so that
    # subclasses overriding them keep working.
    _COMMANDS = {
        'test': '_startTest',
        'testing': '_startTest',
        'error': '_addError',
        'failure': '_addFailure',
        'success': '_addSuccess',
        'successful': '_addSuccess',
    }

    def __init__(self, client, stream=None):
        """Create a TestProtocol server instance.

        client should be an object that provides
         - startTest
         - addSuccess
         - addFailure
         - addError
         - stopTest
        methods, i.e. a TestResult.

        @param stream: File-like object that receives every line which is
            not part of the protocol.  Defaults to the C{sys.stdout} that is
            current when the server is created (not the one that was current
            when this module was imported).
        """
        if stream is None:
            stream = sys.stdout
        self.state = TestProtocolServer.OUTSIDE_TEST
        self.client = client
        self._stream = stream
        # Bookkeeping for the test currently in progress, if any.
        self.current_test_description = None
        self._current_test = None
        self._message = ""

    @staticmethod
    def _lineText(offset, line):
        """Return line from offset on, without its trailing newline.

        Only an actual newline is removed, so the last line of a stream that
        does not end in a newline keeps its final character.
        """
        text = line[offset:]
        if text.endswith("\n"):
            text = text[:-1]
        return text

    def _finishTest(self, reporter_name, *details):
        """Report the outcome of the current test and leave the test.

        @param reporter_name: Name of the client method to call, e.g.
            C{'addError'}.
        @param details: Extra arguments passed to that method after the
            test itself.
        """
        test = self._current_test
        # Reset all per-test bookkeeping in one place so that every outcome
        # leaves the server in the same, clean OUTSIDE_TEST state.
        self.state = TestProtocolServer.OUTSIDE_TEST
        self.current_test_description = None
        self._current_test = None
        getattr(self.client, reporter_name)(test, *details)
        self.client.stopTest(test)

    def _reportOutcome(self, offset, line, reporter_name, quoting_state):
        """Shared implementation of the error and failure commands."""
        if self.state == TestProtocolServer.TEST_STARTED:
            text = self._lineText(offset, line)
            if text == self.current_test_description:
                # Outcome without a bracketed message.
                self._finishTest(reporter_name, RemoteError(""))
                return
            if text == self.current_test_description + " [":
                # A bracketed message follows; collect it until "]".
                self.state = quoting_state
                self._message = ""
                return
        self.stdOutLineReceived(line)

    def _addError(self, offset, line):
        """Handle an 'error' line; line[offset:] names the test."""
        self._reportOutcome(offset, line, 'addError',
                            TestProtocolServer.READING_ERROR)

    def _addFailure(self, offset, line):
        """Handle a 'failure' line; line[offset:] names the test."""
        self._reportOutcome(offset, line, 'addFailure',
                            TestProtocolServer.READING_FAILURE)

    def _addSuccess(self, offset, line):
        """Handle a 'success'/'successful' line."""
        if (self.state == TestProtocolServer.TEST_STARTED and
            self._lineText(offset, line) == self.current_test_description):
            self._finishTest('addSuccess')
        else:
            self.stdOutLineReceived(line)

    def _appendMessage(self, line):
        """Add one line to the bracketed message being collected."""
        if line.startswith(" ]"):
            # quoted ] start: drop the escaping space
            self._message += line[1:]
        else:
            self._message += line

    def endQuote(self, line):
        """Handle the "]" line that closes a bracketed message."""
        if self.state == TestProtocolServer.READING_FAILURE:
            self._finishTest('addFailure', RemoteError(self._message))
        elif self.state == TestProtocolServer.READING_ERROR:
            self._finishTest('addError', RemoteError(self._message))
        else:
            self.stdOutLineReceived(line)

    def lineReceived(self, line):
        """Call the appropriate local method for the received line."""
        if line == "]\n":
            self.endQuote(line)
            return
        if self.state in (TestProtocolServer.READING_FAILURE,
                          TestProtocolServer.READING_ERROR):
            self._appendMessage(line)
            return
        parts = line.split(None, 1)
        # The offset below is only valid when the keyword starts in column
        # 0, so lines with leading whitespace are treated as plain output.
        if len(parts) == 2 and line.startswith(parts[0]):
            keyword = parts[0]
            offset = len(keyword) + 1
            handler_name = self._COMMANDS.get(keyword.strip(':'))
            if handler_name is not None:
                getattr(self, handler_name)(offset, line)
                return
        self.stdOutLineReceived(line)

    def lostConnection(self):
        """The input connection has finished.

        If a test was in progress it is reported to the client as an error
        and the server returns to OUTSIDE_TEST, so calling this method again
        does not report the same test twice.
        """
        templates = {
            TestProtocolServer.TEST_STARTED:
                "lost connection during test '%s'",
            TestProtocolServer.READING_ERROR:
                "lost connection during error report of test '%s'",
            TestProtocolServer.READING_FAILURE:
                "lost connection during failure report of test '%s'",
        }
        template = templates.get(self.state)
        if template is not None:
            message = template % self.current_test_description
            self._finishTest('addError', RemoteError(message))

    def readFrom(self, pipe):
        """Feed every line of pipe to lineReceived(), then signal the end."""
        for line in pipe.readlines():
            self.lineReceived(line)
        self.lostConnection()

    def _startTest(self, offset, line):
        """Internal call to change state machine. Override startTest()."""
        if self.state == TestProtocolServer.OUTSIDE_TEST:
            description = self._lineText(offset, line)
            self.state = TestProtocolServer.TEST_STARTED
            self._current_test = RemotedTestCase(description)
            self.current_test_description = description
            self.client.startTest(self._current_test)
        else:
            self.stdOutLineReceived(line)

    def stdOutLineReceived(self, line):
        """Forward a line that is not part of the protocol to the stream."""
        self._stream.write(line)
196
197
class RemoteException(Exception):
    """An exception that occurred remotely to python.

    Two instances compare equal when their C{args} are equal; comparing
    with an object that has no C{args} attribute yields False.
    """

    def __eq__(self, other):
        try:
            return self.args == other.args
        except AttributeError:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out to keep the
        # two operators consistent.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3.
        # Hash on the same data that __eq__ compares.  Raises TypeError if
        # args contains an unhashable object.
        return hash(self.args)
206
207
class TestProtocolClient(unittest.TestResult):
    """A class that looks like a TestResult and informs a TestProtocolServer.

    Every event is written to the stream as protocol text and is then also
    recorded through the normal unittest.TestResult bookkeeping.
    """

    def __init__(self, stream):
        """Create a client that writes protocol lines to stream."""
        super(TestProtocolClient, self).__init__()
        self._stream = stream

    def _describe(self, test):
        """Return the name under which test is reported on the stream."""
        return test.shortDescription() or str(test)

    def _writeDetails(self, label, test, error):
        """Write a bracketed error/failure report for test.

        @param label: Protocol keyword, C{"error"} or C{"failure"}.
        @param error: An exc_info style tuple.
        """
        self._stream.write("%s: %s [\n" % (label, self._describe(test)))
        for line in self._exc_info_to_string(error, test).splitlines():
            if line.startswith("]"):
                # A line consisting of "]" ends the report on the receiving
                # side.  Quote a leading "]" with a space; the server
                # strips that space again when it sees " ]".
                line = " " + line
            self._stream.write("%s\n" % line)
        self._stream.write("]\n")

    def addError(self, test, error):
        """Report an error in test test."""
        self._writeDetails("error", test, error)
        super(TestProtocolClient, self).addError(test, error)

    def addFailure(self, test, error):
        """Report a failure in test test."""
        self._writeDetails("failure", test, error)
        super(TestProtocolClient, self).addFailure(test, error)

    def addSuccess(self, test):
        """Report a success in a test."""
        self._stream.write("successful: %s\n" % self._describe(test))
        super(TestProtocolClient, self).addSuccess(test)

    def startTest(self, test):
        """Mark a test as starting its test run."""
        self._stream.write("test: %s\n" % self._describe(test))
        super(TestProtocolClient, self).startTest(test)
240
241
def RemoteError(description=""):
    """Build an exc_info style triple describing a remote problem.

    @param description: Text of the problem.  An empty string is replaced
        by a single newline.
    @return: A tuple C{(RemoteException, RemoteException(description),
        None)}; the third item, where a traceback would go, is always None.
    """
    message = "\n" if description == "" else description
    return (RemoteException, RemoteException(message), None)
246
247
class RemotedTestCase(unittest.TestCase):
    """A class to represent test cases run in child processes.

    Instances only carry a description.  They cannot be run locally:
    run() records an error, and setUp()/tearDown() raise
    NotImplementedError.
    """

    def __eq__ (self, other):
        # The name-mangled attribute only exists on RemotedTestCase
        # instances, so anything else compares unequal.
        try:
            return self.__description == other.__description
        except AttributeError:
            return False

    def __ne__(self, other):
        # Keep != consistent with == on interpreters that do not derive it.
        return not self.__eq__(other)

    def __hash__(self):
        # unittest.TestCase.__init__ is not called by this class, so a
        # hash inherited from TestCase cannot rely on the attributes that
        # method sets up.  Hash on the data that __eq__ compares instead.
        return hash(self.__description)

    def __init__(self, description):
        """Create a pseudo test case with description description.

        Note that unittest.TestCase.__init__ is not invoked; only the
        description is stored on the instance.
        """
        self.__description = description

    def error(self, label):
        """Raise NotImplementedError naming the refused operation label."""
        raise NotImplementedError("%s on RemotedTestCases is not permitted." %
            label)

    def setUp(self):
        self.error("setUp")

    def tearDown(self):
        self.error("tearDown")

    def shortDescription(self):
        """Return the description given at construction time."""
        return self.__description

    def id(self):
        return "%s.%s" % (self._strclass(), self.__description)

    def __str__(self):
        return "%s (%s)" % (self.__description, self._strclass())

    def __repr__(self):
        return "<%s description='%s'>" % \
               (self._strclass(), self.__description)

    def run(self, result=None):
        """Record on result that this test case cannot be run locally."""
        if result is None:
            result = self.defaultTestResult()
        result.startTest(self)
        result.addError(self, RemoteError("Cannot run RemotedTestCases.\n"))
        result.stopTest(self)

    def _strclass(self):
        """Return the dotted module.class name of this instance's class."""
        cls = self.__class__
        return "%s.%s" % (cls.__module__, cls.__name__)
293
294
class ExecTestCase(unittest.TestCase):
    """A test case which runs external scripts for test fixtures.

    The docstring of the selected test method is used as the path of the
    script to run, taken relative to the directory of the module that
    defines the concrete class.
    """

    def __init__(self, methodName='runTest'):
        """Create an instance of the class that will use the named test
           method when executed. Raises a ValueError if the instance does
           not have a method with the specified name.
        """
        unittest.TestCase.__init__(self, methodName)
        testMethod = getattr(self, methodName)
        # The method's docstring is the script path; resolve it next to the
        # file of the module in which this class was defined.
        self.script = join_dir(sys.modules[self.__class__.__module__].__file__,
                               testMethod.__doc__)

    def countTestCases(self):
        return 1

    def run(self, result=None):
        """Run the script and report what it outputs to result."""
        if result is None:
            result = self.defaultTestResult()
        self._run(result)

    def debug(self):
        """Run the test without collecting errors in a TestResult"""
        self._run(unittest.TestResult())

    def _run(self, result):
        """Execute the script and feed its output to a protocol server."""
        protocol = TestProtocolServer(result)
        output = os.popen(self.script, mode='r')
        try:
            protocol.readFrom(output)
        finally:
            # Always release the pipe, even if parsing the output failed.
            # TODO: close() yields the script's exit status; it is not
            # evaluated here.
            output.close()
323
324
class IsolatedTestCase(unittest.TestCase):
    """A TestCase whose run() is carried out in a forked process."""

    def run(self, result=None):
        """Run this test case through run_isolated().

        @param result: The TestResult to report to; when omitted, the one
            produced by defaultTestResult() is used.
        """
        target = result
        if target is None:
            target = self.defaultTestResult()
        run_isolated(unittest.TestCase, self, target)
331
332
class IsolatedTestSuite(unittest.TestSuite):
    """A TestSuite whose run() is carried out in a forked process."""

    def run(self, result=None):
        """Run this suite through run_isolated().

        @param result: The TestResult to report to; when omitted, a fresh
            unittest.TestResult is used.
        """
        target = result
        if target is None:
            target = unittest.TestResult()
        run_isolated(unittest.TestSuite, self, target)
339
340
def run_isolated(klass, self, result):
    """Run a test suite or case in a subprocess, using the run method on klass.

    The process is forked.  The child runs C{klass.run(self, ...)} with a
    TestProtocolClient writing to its standard output, which is connected
    to a pipe.  The parent reads that pipe with a TestProtocolServer that
    reports to C{result}.

    @param klass: Class whose C{run} method is invoked in the child.
    @param self: The test case or suite to run.
    @param result: The TestResult that receives the outcomes in the parent.
    @return: C{result}.
    """
    c2pread, c2pwrite = os.pipe()
    # Flush pending output first; otherwise the child inherits a copy of
    # the buffer and would replay it into the pipe.
    sys.stdout.flush()
    sys.stderr.flush()
    # fixme - error -> result
    # now fork
    try:
        pid = os.fork()
    except OSError:
        # Do not leak the pipe when the fork itself fails.
        os.close(c2pread)
        os.close(c2pwrite)
        raise
    if pid == 0:
        # Child
        exit_code = 1
        try:
            try:
                # Close parent's pipe ends
                os.close(c2pread)
                # Dup fds for child
                os.dup2(c2pwrite, 1)
                # Close pipe fds.
                os.close(c2pwrite)

                # at this point, sys.stdout is redirected, now we want
                # to filter it to escape ]'s.
                ### XXX: test and write that bit.

                child_result = TestProtocolClient(sys.stdout)
                klass.run(self, child_result)
                sys.stdout.flush()
                sys.stderr.flush()
                exit_code = 0
            except BaseException:
                # Never let an exception unwind into the parent's code in
                # the forked child; report it and fall through to the exit.
                import traceback
                traceback.print_exc()
                sys.stderr.flush()
        finally:
            # exit HARD, exit NOW.
            os._exit(exit_code)
    else:
        # Parent
        # Close child pipe ends
        os.close(c2pwrite)
        # hookup a protocol engine
        protocol = TestProtocolServer(result)
        pipe = os.fdopen(c2pread, 'rU')
        try:
            protocol.readFrom(pipe)
        finally:
            # Release the pipe and reap the child even if reading failed.
            pipe.close()
            os.waitpid(pid, 0)
        # TODO return code evaluation.
    return result
377
378
class SubunitTestRunner(object):
    """A test runner that reports results as protocol text on a stream."""

    def __init__(self, stream=None):
        """Create a runner.

        @param stream: File-like object the results are written to.
            Defaults to the C{sys.stdout} that is current when the runner
            is created (not the one that was current when this module was
            imported).
        """
        if stream is None:
            stream = sys.stdout
        self.stream = stream

    def run(self, test):
        "Run the given test case or test suite."
        result = TestProtocolClient(self.stream)
        test(result)
        return result
388