Merge python API changes for tagging.
[third_party/subunit] / python / subunit / __init__.py
1 #
2 #  subunit: extensions to Python unittest to get test results from subprocesses.
3 #  Copyright (C) 2005  Robert Collins <robertc@robertcollins.net>
4 #
5 #  This program is free software; you can redistribute it and/or modify
6 #  it under the terms of the GNU General Public License as published by
7 #  the Free Software Foundation; either version 2 of the License, or
8 #  (at your option) any later version.
9 #
10 #  This program is distributed in the hope that it will be useful,
11 #  but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #  GNU General Public License for more details.
14 #
15 #  You should have received a copy of the GNU General Public License
16 #  along with this program; if not, write to the Free Software
17 #  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18 #
19
20 import datetime
21 import os
22 import re
23 from StringIO import StringIO
24 import subprocess
25 import sys
26 import unittest
27
28 import iso8601
29
30
# Values for the ``whence`` argument of progress(): PROGRESS_SET gives the
# remaining test count exactly, PROGRESS_CUR adjusts it relatively, and
# PROGRESS_PUSH / PROGRESS_POP correspond to the "push" / "pop" directives.
PROGRESS_SET, PROGRESS_CUR, PROGRESS_PUSH, PROGRESS_POP = range(4)
35
36
def test_suite():
    """Return the test suite defined by the subunit.tests package."""
    from subunit import tests
    return tests.test_suite()
40
41
def join_dir(base_path, path):
    """Resolve C{path} against the directory that contains C{base_path}.

    @param base_path: A path to a file or directory. It is made absolute
        first, so a relative value is interpreted from the current directory.
    @param path: An absolute path, or a path relative to the directory that
        contains C{base_path}.

    @return: An absolute path to C{path}.
    """
    containing_dir = os.path.dirname(os.path.abspath(base_path))
    return os.path.join(containing_dir, path)
54
55
def tags_to_new_gone(tags):
    """Split a list of tags into a new_set and a gone_set.

    A tag written as "-NAME" is placed in the gone set as "NAME"; any other
    tag goes into the new set unchanged. Empty strings name no tag and are
    ignored rather than raising IndexError.

    :param tags: An iterable of tag strings, e.g. ["foo", "-bar"].
    :return: A (new_tags, gone_tags) pair of sets.
    """
    new_tags = set()
    gone_tags = set()
    for tag in tags:
        if not tag:
            continue
        if tag.startswith('-'):
            gone_tags.add(tag[1:])
        else:
            new_tags.add(tag)
    return new_tags, gone_tags
66
67
class DiscardStream(object):
    """A filelike object which discards what is written to it.

    Alongside write(), the output-side companions writelines() and flush()
    are provided so the object can stand in wherever a writable stream is
    expected; all of them discard their input.
    """

    def write(self, bytes):
        """Discard ``bytes``."""

    def writelines(self, lines):
        """Consume and discard every item of ``lines``, like a real stream."""
        for _ in lines:
            pass

    def flush(self):
        """Do nothing: nothing is ever buffered."""
73
74
class TestProtocolServer(object):
    """A class for receiving results from a TestProtocol client.

    The server is a line-driven state machine. Every line handed to
    lineReceived() either starts a test, reports an outcome for the test in
    progress, opens or closes a multi-line "[ ... ]" detail block, carries a
    progress/tags/time directive, or is not protocol at all and is copied
    verbatim to the passthrough stream.

    :ivar state: One of the parser-state constants defined on this class.
    :ivar client: The result object that parsed events are reported to.
    :ivar current_test_description: The id of the test in progress, or None
        when no test is in progress.
    """

    # Parser states.
    OUTSIDE_TEST = 0
    TEST_STARTED = 1
    READING_FAILURE = 2
    READING_ERROR = 3
    READING_SKIP = 4
    READING_XFAIL = 5
    READING_SUCCESS = 6

    def __init__(self, client, stream=None):
        """Create a TestProtocol server instance.

        :param client: An object meeting the unittest.TestResult protocol.
        :param stream: The stream that lines received which are not part of the
            subunit protocol should be written to. This allows custom handling
            of mixed protocols. By default, sys.stdout will be used for
            convenience.
        """
        self.state = TestProtocolServer.OUTSIDE_TEST
        self.client = client
        if stream is None:
            stream = sys.stdout
        self._stream = stream
        # Per-test bookkeeping, initialised here so the attributes exist
        # before the first "test:" line arrives.
        self._current_test = None
        self.current_test_description = None
        # Text accumulated from a "[ ... ]" detail block.
        self._message = ""

    def _outcomeForm(self, offset, line):
        """Classify an outcome line against the test in progress.

        :param offset: Index in ``line`` where the test id starts.
        :param line: The raw protocol line, including its trailing newline.
        :return: "simple" for "<outcome>: <id>", "details" for
            "<outcome>: <id> [", or None when no test is in progress or the
            id does not name the test in progress.
        """
        if self.state != TestProtocolServer.TEST_STARTED:
            return None
        name = line[offset:-1]
        if name == self.current_test_description:
            return "simple"
        if name == self.current_test_description + " [":
            return "details"
        return None

    def _readDetails(self, state):
        """Switch to a detail-reading state with an empty message buffer."""
        self.state = state
        self._message = ""

    def _leaveTest(self):
        """Return the parser to OUTSIDE_TEST and drop the test description."""
        self.state = TestProtocolServer.OUTSIDE_TEST
        self.current_test_description = None

    def _stopTest(self):
        """Tell the client the test in progress has ended, then forget it."""
        self.client.stopTest(self._current_test)
        self._current_test = None

    def _addError(self, offset, line):
        """Handle an "error:" line."""
        form = self._outcomeForm(offset, line)
        if form == "simple":
            self._leaveTest()
            self.client.addError(self._current_test, RemoteError(""))
            self._stopTest()
        elif form == "details":
            self._readDetails(TestProtocolServer.READING_ERROR)
        else:
            self.stdOutLineReceived(line)

    def _addExpectedFail(self, offset, line):
        """Handle an "xfail:" line."""
        form = self._outcomeForm(offset, line)
        if form == "simple":
            self._leaveTest()
            self._xfail_or_success()
            self._stopTest()
        elif form == "details":
            self._readDetails(TestProtocolServer.READING_XFAIL)
        else:
            self.stdOutLineReceived(line)

    def _addFailure(self, offset, line):
        """Handle a "failure:" line."""
        form = self._outcomeForm(offset, line)
        if form == "simple":
            self._leaveTest()
            self.client.addFailure(self._current_test, RemoteError())
            self._stopTest()
        elif form == "details":
            self._readDetails(TestProtocolServer.READING_FAILURE)
        else:
            self.stdOutLineReceived(line)

    def _addSkip(self, offset, line):
        """Handle a "skip:" line."""
        form = self._outcomeForm(offset, line)
        if form == "simple":
            self._leaveTest()
            self._skip_or_error()
            self._stopTest()
        elif form == "details":
            self._readDetails(TestProtocolServer.READING_SKIP)
        else:
            self.stdOutLineReceived(line)

    def _skip_or_error(self, message=None):
        """Report the current test as a skip if possible, or else an error."""
        addSkip = getattr(self.client, 'addSkip', None)
        if not callable(addSkip):
            self.client.addError(self._current_test, RemoteError(message))
        else:
            if not message:
                message = "No reason given"
            addSkip(self._current_test, message)

    def _xfail_or_success(self, message=""):
        """Report the current test as an expected failure if the client
        supports addExpectedFailure, or else as a success."""
        xfail = getattr(self.client, 'addExpectedFailure', None)
        if callable(xfail):
            xfail(self._current_test, RemoteError(message))
        else:
            self.client.addSuccess(self._current_test)

    def _addSuccess(self, offset, line):
        """Handle a "success:"/"successful:" line."""
        form = self._outcomeForm(offset, line)
        if form == "simple":
            self._succeedTest()
        elif form == "details":
            self._readDetails(TestProtocolServer.READING_SUCCESS)
        else:
            self.stdOutLineReceived(line)

    def _appendMessage(self, line):
        """Add a detail line to the message, unquoting a leading " ]"."""
        if line[0:2] == " ]":
            # quoted ] start
            self._message += line[1:]
        else:
            self._message += line

    def endQuote(self, line):
        """Close a "[ ... ]" detail block and report the outcome it carried.

        Outside a detail block the line is passed through unchanged.
        """
        if self.state == TestProtocolServer.READING_FAILURE:
            self._leaveTest()
            self.client.addFailure(self._current_test,
                                   RemoteError(self._message))
            self._stopTest()
        elif self.state == TestProtocolServer.READING_ERROR:
            self._leaveTest()
            self.client.addError(self._current_test,
                                 RemoteError(self._message))
            self._stopTest()
        elif self.state == TestProtocolServer.READING_SKIP:
            self._leaveTest()
            self._skip_or_error(self._message)
            self._stopTest()
        elif self.state == TestProtocolServer.READING_XFAIL:
            self._leaveTest()
            self._xfail_or_success(self._message)
            self._stopTest()
        elif self.state == TestProtocolServer.READING_SUCCESS:
            # The detail text of a success is not forwarded.
            self._succeedTest()
        else:
            self.stdOutLineReceived(line)

    def _handleProgress(self, offset, line):
        """Process a progress directive.

        "+N"/"-N" adjust the count, "push"/"pop" are forwarded as such, and
        any other value sets the count; the client is only called if it has
        a callable ``progress`` attribute.
        """
        line = line[offset:].strip()
        if line[0] in '+-':
            whence = PROGRESS_CUR
            delta = int(line)
        elif line == "push":
            whence = PROGRESS_PUSH
            delta = None
        elif line == "pop":
            whence = PROGRESS_POP
            delta = None
        else:
            whence = PROGRESS_SET
            delta = int(line)
        progress_method = getattr(self.client, 'progress', None)
        if callable(progress_method):
            progress_method(delta, whence)

    def _handleTags(self, offset, line):
        """Process a tags command, forwarding (new_tags, gone_tags) to the
        client if it has a callable ``tags`` attribute."""
        tags = line[offset:].split()
        new_tags, gone_tags = tags_to_new_gone(tags)
        tags_method = getattr(self.client, 'tags', None)
        if callable(tags_method):
            tags_method(new_tags, gone_tags)

    def _handleTime(self, offset, line):
        """Parse a time directive and forward it to the client's time(), if
        the client has one."""
        try:
            event_time = iso8601.parse_date(line[offset:-1])
        except TypeError:
            # sys.exc_info() rather than "except ..., e" keeps this valid
            # syntax on both old and new Python versions.
            raise TypeError("Failed to parse %r, got %r" % (
                line, sys.exc_info()[1]))
        time_method = getattr(self.client, 'time', None)
        if callable(time_method):
            time_method(event_time)

    def lineReceived(self, line):
        """Call the appropriate local method for the received line."""
        if line == "]\n":
            self.endQuote(line)
        elif self.state in (TestProtocolServer.READING_FAILURE,
            TestProtocolServer.READING_ERROR, TestProtocolServer.READING_SKIP,
            TestProtocolServer.READING_SUCCESS,
            TestProtocolServer.READING_XFAIL
            ):
            # Inside a "[ ... ]" block every other line is detail text.
            self._appendMessage(line)
        else:
            parts = line.split(None, 1)
            if len(parts) == 2:
                cmd = parts[0]
                # NOTE: the offset assumes the command starts at column 0;
                # split() would hide any leading whitespace.
                offset = len(cmd) + 1
                cmd = cmd.strip(':')
                if cmd in ('test', 'testing'):
                    self._startTest(offset, line)
                elif cmd == 'error':
                    self._addError(offset, line)
                elif cmd == 'failure':
                    self._addFailure(offset, line)
                elif cmd == 'progress':
                    self._handleProgress(offset, line)
                elif cmd == 'skip':
                    self._addSkip(offset, line)
                elif cmd in ('success', 'successful'):
                    self._addSuccess(offset, line)
                elif cmd == 'tags':
                    self._handleTags(offset, line)
                elif cmd == 'time':
                    self._handleTime(offset, line)
                elif cmd == 'xfail':
                    self._addExpectedFail(offset, line)
                else:
                    self.stdOutLineReceived(line)
            else:
                self.stdOutLineReceived(line)

    def _lostConnectionInTest(self, state_string):
        """Report the interrupted test as an error and close it."""
        error_string = "lost connection during %stest '%s'" % (
            state_string, self.current_test_description)
        self.client.addError(self._current_test, RemoteError(error_string))
        self._stopTest()
        self._leaveTest()

    def lostConnection(self):
        """The input connection has finished."""
        if self.state == TestProtocolServer.OUTSIDE_TEST:
            return
        if self.state == TestProtocolServer.TEST_STARTED:
            self._lostConnectionInTest('')
        elif self.state == TestProtocolServer.READING_ERROR:
            self._lostConnectionInTest('error report of ')
        elif self.state == TestProtocolServer.READING_FAILURE:
            self._lostConnectionInTest('failure report of ')
        elif self.state == TestProtocolServer.READING_SUCCESS:
            self._lostConnectionInTest('success report of ')
        elif self.state == TestProtocolServer.READING_SKIP:
            self._lostConnectionInTest('skip report of ')
        elif self.state == TestProtocolServer.READING_XFAIL:
            self._lostConnectionInTest('xfail report of ')
        else:
            self._lostConnectionInTest('unknown state of ')

    def readFrom(self, pipe):
        """Parse every line of ``pipe``, then signal the end of input.

        Lines are read one at a time with readline(), so each line is
        processed as soon as it is read rather than after the whole input
        has been buffered.
        """
        line = pipe.readline()
        while line:
            self.lineReceived(line)
            line = pipe.readline()
        self.lostConnection()

    def _startTest(self, offset, line):
        """Internal call to change state machine. Override startTest()."""
        if self.state == TestProtocolServer.OUTSIDE_TEST:
            self.state = TestProtocolServer.TEST_STARTED
            self._current_test = RemotedTestCase(line[offset:-1])
            self.current_test_description = line[offset:-1]
            self.client.startTest(self._current_test)
        else:
            self.stdOutLineReceived(line)

    def stdOutLineReceived(self, line):
        """Copy a non-protocol line to the passthrough stream."""
        self._stream.write(line)

    def _succeedTest(self):
        """Report the current test as successful and close it."""
        self.client.addSuccess(self._current_test)
        self._stopTest()
        self._leaveTest()
350
351
class RemoteException(Exception):
    """An exception that occurred remotely to Python.

    Two instances compare equal when their ``args`` are equal; hashing is
    based on ``args`` too so that hash() stays consistent with ==.
    """

    def __eq__(self, other):
        try:
            return self.args == other.args
        except AttributeError:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ makes a class unhashable on Python 3; hash on the
        # same data __eq__ compares so equal instances hash equally.
        return hash(self.args)
360
361
class TestProtocolClient(unittest.TestResult):
    """A class that looks like a TestResult and informs a TestProtocolServer.

    Every reported event is written to the wrapped stream as subunit
    protocol text, which a TestProtocolServer can parse back into
    TestResult calls.
    """

    def __init__(self, stream):
        """Create a client that writes protocol text to ``stream``."""
        unittest.TestResult.__init__(self)
        self._stream = stream

    def _writeDetails(self, outcome, test, lines):
        """Write "<outcome>: <id> [", the detail lines, then a closing "]".

        A detail line beginning with "]" is written with one leading space;
        otherwise a lone "]" would end the block early. TestProtocolServer
        removes that space again when it reads a line starting with " ]".
        """
        self._stream.write("%s: %s [\n" % (outcome, test.id()))
        for line in lines:
            if line.startswith("]"):
                line = " " + line
            self._stream.write("%s\n" % line)
        self._stream.write("]\n")

    def addError(self, test, error):
        """Report an error in test test."""
        self._writeDetails("error", test,
            self._exc_info_to_string(error, test).splitlines())

    def addFailure(self, test, error):
        """Report a failure in test test."""
        self._writeDetails("failure", test,
            self._exc_info_to_string(error, test).splitlines())

    def addSkip(self, test, reason):
        """Report a skipped test.

        The reason is rendered with "%s" and written one newline-separated
        piece per line.
        """
        self._writeDetails("skip", test, ("%s" % (reason,)).split("\n"))

    def addSuccess(self, test):
        """Report a success in a test."""
        self._stream.write("successful: %s\n" % test.id())

    def startTest(self, test):
        """Mark a test as starting its test run."""
        self._stream.write("test: %s\n" % test.id())

    def progress(self, offset, whence):
        """Provide indication about the progress/length of the test run.

        :param offset: Information about the number of tests remaining. If
            whence is PROGRESS_CUR, then offset increases/decreases the
            remaining test count. If whence is PROGRESS_SET, then offset
            specifies exactly the remaining test count.
        :param whence: One of PROGRESS_CUR, PROGRESS_SET, PROGRESS_PUSH,
            PROGRESS_POP.
        """
        if whence == PROGRESS_CUR and offset > -1:
            prefix = "+"
        elif whence == PROGRESS_PUSH:
            prefix = ""
            offset = "push"
        elif whence == PROGRESS_POP:
            prefix = ""
            offset = "pop"
        else:
            prefix = ""
        self._stream.write("progress: %s%s\n" % (prefix, offset))

    def time(self, a_datetime):
        """Inform the client of the time.

        :param a_datetime: A datetime.datetime object. It is converted to UTC
            with astimezone() before being written, so it should carry a
            tzinfo.
        """
        time = a_datetime.astimezone(iso8601.Utc())
        self._stream.write("time: %04d-%02d-%02d %02d:%02d:%02d.%06dZ\n" % (
            time.year, time.month, time.day, time.hour, time.minute,
            time.second, time.microsecond))

    def done(self):
        """Obey the testtools result.done() interface."""
431
432
def RemoteError(description=""):
    """Build an exc_info-style triple describing an error raised remotely.

    An empty description is replaced by a single newline. The traceback
    slot of the returned triple is always None.

    :return: (RemoteException, RemoteException(text), None).
    """
    text = "\n" if description == "" else description
    error = RemoteException(text)
    return RemoteException, error, None
437
438
class RemotedTestCase(unittest.TestCase):
    """A class to represent test cases run in child processes.

    Instances of this class are used to provide the Python test API a TestCase
    that can be printed to the screen, introspected for metadata and so on.
    However, as they are simply a memoisation of a test that was actually
    run in the past by a separate process, they cannot perform any interactive
    actions.
    """

    def __eq__(self, other):
        try:
            return self.__description == other.__description
        except AttributeError:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Equal instances share a description, so hashing it keeps hash()
        # consistent with __eq__. Without this the class is unhashable on
        # Python 3, and unittest.TestCase.__hash__ (where defined) reads
        # _testMethodName, which is never set because TestCase.__init__ is
        # not called here.
        return hash(self.__description)

    def __init__(self, description):
        """Create a pseudo test case with description description."""
        self.__description = description

    def error(self, label):
        raise NotImplementedError("%s on RemotedTestCases is not permitted." %
            label)

    def setUp(self):
        self.error("setUp")

    def tearDown(self):
        self.error("tearDown")

    def shortDescription(self):
        return self.__description

    def id(self):
        return "%s" % (self.__description,)

    def __str__(self):
        return "%s (%s)" % (self.__description, self._strclass())

    def __repr__(self):
        return "<%s description='%s'>" % \
               (self._strclass(), self.__description)

    def run(self, result=None):
        """Report this test as an error: it cannot be re-run locally.

        :return: The result reported to, matching unittest's run() on
            Python 3.3 and later.
        """
        if result is None:
            result = self.defaultTestResult()
        result.startTest(self)
        result.addError(self, RemoteError("Cannot run RemotedTestCases.\n"))
        result.stopTest(self)
        return result

    def _strclass(self):
        cls = self.__class__
        return "%s.%s" % (cls.__module__, cls.__name__)
491
492
class ExecTestCase(unittest.TestCase):
    """A test case which runs external scripts for test fixtures.

    The docstring of each test method names the command to execute; it is
    resolved relative to the directory of the module that defines the test
    class. The command's stdout is parsed as a subunit stream and replayed
    into the TestResult.
    """

    def __init__(self, methodName='runTest'):
        """Create a test that will execute the named method's script.

        An unknown methodName is rejected by unittest.TestCase.__init__
        (which raises ValueError).
        """
        unittest.TestCase.__init__(self, methodName)
        script_name = getattr(self, methodName).__doc__
        defining_module = sys.modules[self.__class__.__module__]
        self.script = join_dir(defining_module.__file__, script_name)

    def countTestCases(self):
        """This case always counts as exactly one test."""
        return 1

    def run(self, result=None):
        if result is None:
            result = self.defaultTestResult()
        self._run(result)

    def debug(self):
        """Run the test, reporting into a throwaway unittest.TestResult."""
        self._run(unittest.TestResult())

    def _run(self, result):
        """Execute the script and feed its stdout to a protocol server."""
        server = TestProtocolServer(result)
        process = subprocess.Popen(self.script, shell=True,
            stdout=subprocess.PIPE)
        stdout_text, _ = process.communicate()
        server.readFrom(StringIO(stdout_text))
522
523
class IsolatedTestCase(unittest.TestCase):
    """A TestCase which runs its tests in a forked process."""

    def run(self, result=None):
        """Run this test in a forked child process, reporting into result.

        :param result: The TestResult to report to; defaultTestResult() is
            used when omitted.
        :return: The result reported to, as unittest.TestCase.run() does on
            Python 3.3 and later.
        """
        if result is None:
            result = self.defaultTestResult()
        run_isolated(unittest.TestCase, self, result)
        return result
530
531
class IsolatedTestSuite(unittest.TestSuite):
    """A TestSuite which runs its tests in a forked process."""

    def run(self, result=None):
        """Run the whole suite in one forked child process.

        :param result: The TestResult to report to; a new
            unittest.TestResult is used when omitted.
        :return: The result reported to, as unittest.TestSuite.run() does.
        """
        if result is None:
            result = unittest.TestResult()
        run_isolated(unittest.TestSuite, self, result)
        return result
538
539
def run_isolated(klass, self, result):
    """Run a test suite or case in a subprocess, using the run method on klass.

    The child process calls ``klass.run(self, client)`` with its stdout
    redirected into a pipe and a TestProtocolClient as the result. The
    parent parses that stream with a TestProtocolServer into ``result``.

    :param klass: The class whose ``run`` method is invoked in the child
        (e.g. unittest.TestCase or unittest.TestSuite).
    :param self: The test or suite to run.
    :param result: The TestResult the parent reports into.
    :return: ``result``.
    """
    # Flush first: anything still buffered would otherwise be copied into
    # the child and written a second time, this time into the pipe.
    sys.stdout.flush()
    sys.stderr.flush()
    c2pread, c2pwrite = os.pipe()
    # fixme - error -> result
    try:
        pid = os.fork()
    except OSError:
        os.close(c2pread)
        os.close(c2pwrite)
        raise
    if pid == 0:
        # Child. It must always finish with os._exit below and never return
        # into the caller's code, which belongs to the parent process.
        exit_code = 1
        try:
            try:
                # Close the parent's pipe end and make the pipe our stdout.
                os.close(c2pread)
                os.dup2(c2pwrite, 1)
                os.close(c2pwrite)
                # at this point, sys.stdout is redirected, now we want
                # to filter it to escape ]'s.
                ### XXX: test and write that bit.
                client = TestProtocolClient(sys.stdout)
                klass.run(self, client)
                exit_code = 0
            except:
                # Deliberately catch everything (SystemExit and
                # KeyboardInterrupt included): print it to stderr and let the
                # child exit with a non-zero status below.
                sys.excepthook(*sys.exc_info())
        finally:
            try:
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                # exit HARD, exit NOW.
                os._exit(exit_code)
    # Parent.
    os.close(c2pwrite)
    try:
        pipe = os.fdopen(c2pread, 'rU')
        try:
            protocol = TestProtocolServer(result)
            protocol.readFrom(pipe)
        finally:
            pipe.close()
    finally:
        # Reap the child even if parsing failed so it is not left a zombie.
        os.waitpid(pid, 0)
        # TODO return code evaluation.
    return result
576
577
def TAP2SubUnit(tap, subunit):
    """Filter a TAP pipe into a subunit pipe.

    Each TAP test line becomes a subunit test: ``ok`` maps to success,
    ``not ok`` to failure, a ``# TODO`` directive to xfail and ``# SKIP`` to
    skip. Tests promised by the plan but absent from the TAP output are
    reported as errors, ``Bail out!`` is reported as an error test named
    after the bail-out message, and a ``1..0`` plan reports a skipped
    "file skip" test. ``#`` comment lines are added to the details of the
    test being accumulated (the most recent test line, or the first test if
    none has been seen yet). Any other line is copied through unchanged.

    :param tap: A tap pipe/stream/file object.
    :param subunit: A pipe/stream/file object to write subunit results to.
    :return: The exit code to exit with.
    """
    BEFORE_PLAN = 0
    AFTER_PLAN = 1
    SKIP_STREAM = 2
    # Only BEFORE_PLAN is tested below: once a plan line has been seen, no
    # further plan lines are recognised.
    # Patterns are compiled once rather than looked up for every line.
    plan_re = re.compile(r"(\d+)\.\.(\d+)\s*(?:\#\s+(.*))?\n")
    test_re = re.compile(
        r"(ok|not ok)(?:\s+(\d+)?)?(?:\s+([^#]*[^#\s]+)\s*)?"
        r"(?:\s+#\s+(TODO|SKIP)(?:\s+(.*))?)?\n")
    bail_re = re.compile(r"Bail out\!(?:\s*(.*))?\n")
    comment_re = re.compile(r"\#.*\n")
    state = BEFORE_PLAN
    plan_start = 1
    plan_stop = 0

    def _skipped_test(subunit, plan_start):
        """Report test number plan_start as missing; return the next number."""
        subunit.write('test test %d\n' % plan_start)
        subunit.write('error test %d [\n' % plan_start)
        subunit.write('test missing from TAP output\n')
        subunit.write(']\n')
        return plan_start + 1

    # Test data for the next test to emit
    test_name = None
    log = []
    result = None

    def _emit_test():
        """Write out the pending test (if any) with its collected log."""
        if test_name is None:
            return
        subunit.write("test %s\n" % test_name)
        if log:
            subunit.write("%s %s [\n" % (result, test_name))
            for entry in log:
                subunit.write("%s\n" % entry)
            subunit.write("]\n")
        else:
            subunit.write("%s %s\n" % (result, test_name))
        del log[:]

    for line in tap:
        if state == BEFORE_PLAN:
            match = plan_re.match(line)
            if match:
                state = AFTER_PLAN
                _, plan_stop, comment = match.groups()
                plan_stop = int(plan_stop)
                if plan_start > plan_stop and plan_stop == 0:
                    # skipped file
                    state = SKIP_STREAM
                    subunit.write("test file skip\n")
                    if comment is None:
                        # No reason given: use the bracket-less form instead
                        # of writing the literal text "None" as the reason.
                        subunit.write("skip file skip\n")
                    else:
                        subunit.write("skip file skip [\n")
                        subunit.write("%s\n" % comment)
                        subunit.write("]\n")
                continue
        # not a plan line, or have seen one before
        match = test_re.match(line)
        if match:
            # new test, emit current one.
            _emit_test()
            status, number, description, directive, directive_comment = \
                match.groups()
            if status == 'ok':
                result = 'success'
            else:
                result = "failure"
            if description is None:
                description = ''
            else:
                description = ' ' + description
            if directive is not None:
                if directive == 'TODO':
                    result = 'xfail'
                elif directive == 'SKIP':
                    result = 'skip'
                if directive_comment is not None:
                    log.append(directive_comment)
            if number is not None:
                number = int(number)
                # Any numbers jumped over were never reported.
                while plan_start < number:
                    plan_start = _skipped_test(subunit, plan_start)
            test_name = "test %d%s" % (plan_start, description)
            plan_start += 1
            continue
        match = bail_re.match(line)
        if match:
            reason, = match.groups()
            if reason is None:
                extra = ''
            else:
                extra = ' %s' % reason
            _emit_test()
            test_name = "Bail out!%s" % extra
            result = "error"
            state = SKIP_STREAM
            continue
        match = comment_re.match(line)
        if match:
            log.append(line[:-1])
            continue
        subunit.write(line)
    _emit_test()
    while plan_start <= plan_stop:
        # record missed tests
        plan_start = _skipped_test(subunit, plan_start)
    return 0
682
683
def tag_stream(original, filtered, tags):
    """Alter tags on a stream.

    :param original: The input stream.
    :param filtered: The output stream.
    :param tags: The tags to apply. As in a normal stream - a list of 'TAG' or
        '-TAG' commands.

        A 'TAG' command will add the tag to the output stream,
        and override any existing '-TAG' command in that stream.
        Specifically:
         * A global 'tags: TAG' will be added to the start of the stream.
         * Any tags commands with -TAG will have the -TAG removed.

        A '-TAG' command will remove the TAG command from the stream.
        Specifically:
         * A 'tags: -TAG' command will be added to the start of the stream.
         * Any 'tags: TAG' command will have 'TAG' removed from it.
        Additionally, any redundant tagging commands (adding a tag globally
        present, or removing a tag globally removed) are stripped as a
        by-product of the filtering.

        Tags on each emitted 'tags:' line are space separated, new tags
        first, each group in sorted order so output is deterministic.
    :return: 0
    """
    new_tags, gone_tags = tags_to_new_gone(tags)

    def write_tags(new_tags, gone_tags):
        """Write one 'tags:' line for the sets; nothing if both are empty."""
        if not (new_tags or gone_tags):
            return
        words = sorted(new_tags) + ["-" + tag for tag in sorted(gone_tags)]
        filtered.write("tags: %s\n" % " ".join(words))

    write_tags(new_tags, gone_tags)
    # TODO: use the protocol parser and thus don't mangle test comments.
    for line in original:
        if line.startswith("tags:"):
            line_tags = line[5:].split()
            line_new, line_gone = tags_to_new_gone(line_tags)
            line_new = line_new - gone_tags
            line_gone = line_gone - new_tags
            write_tags(line_new, line_gone)
        else:
            filtered.write(line)
    return 0
727
728
class ProtocolTestCase(object):
    """A test case which reports a subunit stream.

    Calling or running an instance parses the stream line by line and
    replays the test activity it describes into a result object.
    """

    def __init__(self, stream, passthrough=None):
        """Create a ProtocolTestCase reading from stream.

        :param stream: A filelike object which a subunit stream can be read
            from.
        :param passthrough: A stream pass non subunit input on to. If not
            supplied, the TestProtocolServer default is used.
        """
        self._stream = stream
        self._passthrough = passthrough

    def __call__(self, result=None):
        return self.run(result)

    def defaultTestResult(self):
        """Return the result used when run() is given none.

        This class does not inherit from unittest.TestCase, so it has to
        provide this itself.
        """
        return unittest.TestResult()

    def run(self, result=None):
        """Replay the stream into ``result``.

        :param result: The TestResult to report to; defaultTestResult() is
            used when omitted.
        :return: The result that was reported to.
        """
        if result is None:
            result = self.defaultTestResult()
        protocol = TestProtocolServer(result, self._passthrough)
        # readline() handles each line as soon as it has been read.
        line = self._stream.readline()
        while line:
            protocol.lineReceived(line)
            line = self._stream.readline()
        protocol.lostConnection()
        return result
755
756
class TestResultStats(unittest.TestResult):
    """A pyunit TestResult interface implementation for making statistics.

    :ivar total_tests: The total tests seen (the inherited testsRun count).
    :ivar passed_tests: Tests that were neither failed nor skipped.
    :ivar failed_tests: The tests that errored or failed.
    :ivar skipped_tests: The tests that were skipped.
    :ivar seen_tags: The tags seen across all tests.
    """

    def __init__(self, stream):
        """Create a TestResultStats which outputs to stream."""
        unittest.TestResult.__init__(self)
        self._stream = stream
        self.seen_tags = set()
        self.skipped_tests = 0
        self.failed_tests = 0

    @property
    def total_tests(self):
        return self.testsRun

    @property
    def passed_tests(self):
        return self.total_tests - (self.failed_tests + self.skipped_tests)

    def addError(self, test, err):
        # Errors and failures are both counted as failed tests.
        self.failed_tests = self.failed_tests + 1

    def addFailure(self, test, err):
        self.failed_tests = self.failed_tests + 1

    def addSkip(self, test, reason):
        self.skipped_tests = self.skipped_tests + 1

    def formatStats(self):
        """Write a plain-text summary of the counters and seen tags."""
        summary = [
            "Total tests:   %5d\n" % self.total_tests,
            "Passed tests:  %5d\n" % self.passed_tests,
            "Failed tests:  %5d\n" % self.failed_tests,
            "Skipped tests: %5d\n" % self.skipped_tests,
            "Seen tags: %s\n" % (", ".join(sorted(self.seen_tags))),
        ]
        for text in summary:
            self._stream.write(text)

    def tags(self, new_tags, gone_tags):
        """Accumulate the seen tags; removed tags are not forgotten."""
        for tag in new_tags:
            self.seen_tags.add(tag)

    def wasSuccessful(self):
        """Tells whether or not this result was a success"""
        return not self.failed_tests
806
807
class TestResultFilter(unittest.TestResult):
    """A pyunit TestResult interface implementation which filters tests.

    Tests that pass the filter are handed on to another TestResult instance
    for further processing/reporting. To obtain the filtered results,
    the other instance must be interrogated.

    startTest is deferred. The wrapped result only receives startTest (and
    later stopTest) for a test whose outcome passes the filter. A filtered
    test therefore never produces a startTest/stopTest pair on the wrapped
    result.

    :ivar result: The result that tests are passed to after filtering.
    :ivar filter_predicate: The callback run to decide whether to pass
        a result.
    """

    def __init__(self, result, filter_error=False, filter_failure=False,
        filter_success=True, filter_skip=False,
        filter_predicate=None):
        """Create a FilterResult object filtering to result.

        :param result: The TestResult that unfiltered outcomes are sent to.
        :param filter_error: Filter out errors.
        :param filter_failure: Filter out failures.
        :param filter_success: Filter out successful tests.
        :param filter_skip: Filter out skipped tests.
        :param filter_predicate: A callable taking (test, err) and
            returning True if the result should be passed through.
            err is None for success, and the skip reason for skips.
        """
        unittest.TestResult.__init__(self)
        self.result = result
        self._filter_error = filter_error
        self._filter_failure = filter_failure
        self._filter_success = filter_success
        self._filter_skip = filter_skip
        if filter_predicate is None:
            filter_predicate = lambda test, err: True
        self.filter_predicate = filter_predicate
        # The current test (for filtering tags)
        self._current_test = None
        # Has the current test been filtered (for outputting test tags)
        self._current_test_filtered = None
        # The (new, gone) tags for the current test.
        self._current_test_tags = None

    def _filtered(self):
        """Record that the current test's outcome was not passed through.

        stopTest checks this flag. When it is set, stopTest neither replays
        tags nor calls stopTest on self.result, so the wrapped result never
        sees a stopTest without a matching startTest.
        """
        self._current_test_filtered = True

    def addError(self, test, err):
        """Forward an error unless errors are filtered or the predicate rejects it."""
        if not self._filter_error and self.filter_predicate(test, err):
            self.result.startTest(test)
            self.result.addError(test, err)
        else:
            self._filtered()

    def addFailure(self, test, err):
        """Forward a failure unless failures are filtered or the predicate rejects it."""
        if not self._filter_failure and self.filter_predicate(test, err):
            self.result.startTest(test)
            self.result.addFailure(test, err)
        else:
            self._filtered()

    def addSkip(self, test, reason):
        """Forward a skip unless skips are filtered or the predicate rejects it."""
        if not self._filter_skip and self.filter_predicate(test, reason):
            self.result.startTest(test)
            # This is duplicated, it would be nice to have on a 'calls
            # TestResults' mixin perhaps.
            addSkip = getattr(self.result, 'addSkip', None)
            if not callable(addSkip):
                # The wrapped result has no callable addSkip, so report
                # the skip as an error instead.
                self.result.addError(test, RemoteError(reason))
            else:
                self.result.addSkip(test, reason)
        else:
            self._filtered()

    def addSuccess(self, test):
        """Forward a success unless successes are filtered or the predicate rejects it."""
        if not self._filter_success and self.filter_predicate(test, None):
            self.result.startTest(test)
            self.result.addSuccess(test)
        else:
            self._filtered()

    def startTest(self, test):
        """Start a test.

        Not directly passed to the client, but used for handling of tags
        correctly.
        """
        self._current_test = test
        self._current_test_filtered = False
        self._current_test_tags = set(), set()

    def stopTest(self, test):
        """Stop a test.

        Not directly passed to the client, but used for handling of tags
        correctly. If the test's outcome was forwarded, any tags gathered
        during the test are replayed to the wrapped result before its
        stopTest.
        """
        if not self._current_test_filtered:
            # Tags to output for this test.
            if self._current_test_tags[0] or self._current_test_tags[1]:
                tags_method = getattr(self.result, 'tags', None)
                if callable(tags_method):
                    self.result.tags(*self._current_test_tags)
            self.result.stopTest(test)
        self._current_test = None
        self._current_test_filtered = None
        self._current_test_tags = None

    def tags(self, new_tags, gone_tags):
        """Handle tag instructions.

        Adds and removes tags as appropriate. If a test is currently running,
        the tags are also accumulated so that stopTest can replay them to the
        wrapped result for that test (only when the test was not filtered).

        :param new_tags: Tags to add,
        :param gone_tags: Tags to remove.
        :return: Whatever ``self.result.tags`` returns, or None when the
            wrapped result has no callable ``tags`` method.
        """
        if self._current_test is not None:
            # Gather the tags until the test stops. A tag that is added and
            # then removed (or vice versa) ends up only in the set that
            # matches its latest state.
            new, gone = self._current_test_tags
            new.update(new_tags)
            new.difference_update(gone_tags)
            gone.update(gone_tags)
            gone.difference_update(new_tags)
        # NOTE(review): tags are also forwarded immediately, even while a test
        # is running and before self.result has seen startTest for it. Confirm
        # that this does not leak test-local tags to later tests.
        tags_method = getattr(self.result, 'tags', None)
        if not callable(tags_method):
            return None
        return tags_method(new_tags, gone_tags)

    def id_to_orig_id(self, id):
        """Strip the ``subunit.RemotedTestCase.`` prefix from a test id.

        :param id: A test id string.
        :return: ``id`` without the prefix, or ``id`` unchanged when it does
            not start with that prefix.
        """
        prefix = "subunit.RemotedTestCase."
        if id.startswith(prefix):
            return id[len(prefix):]
        return id
926