Handle text stdin and stdout streams.
authorRobert Collins <robertc@robertcollins.net>
Mon, 7 May 2012 20:27:36 +0000 (08:27 +1200)
committerRobert Collins <robertc@robertcollins.net>
Mon, 7 May 2012 20:27:36 +0000 (08:27 +1200)
NEWS
python/subunit/__init__.py
python/subunit/tests/test_subunit_filter.py

diff --git a/NEWS b/NEWS
index 479d9953b21b4f8b7bddf4dbed6c16532af9b69d..f97d1e1c6d6de6f2f5a71757a0fdeaa2ddf9a823 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -34,6 +34,9 @@ BUG FIXES
 * Python3 support regressed in trunk.
   (Arfrever Frehtes Taifersar Arahesis, #987514)
 
+* Python3 support was insufficiently robust in detecting unicode streams.
+  (Robert Collins)
+
 * Tag support has been implemented for TestProtocolClient.
   (Robert Collins, #518016)
 
index 8f2a9eda6a53a3ca56a37300fa92fcd87c0db159..69ccc26e025a6d31f4b3636ec694bac3e595f8db 100644 (file)
@@ -121,6 +121,11 @@ import re
 import subprocess
 import sys
 import unittest
+if sys.version_info > (3, 0):
+    from io import UnsupportedOperation as _UnsupportedOperation
+else:
+    _UnsupportedOperation = AttributeError
+
 
 from testtools import content, content_type, ExtendedToOriginalDecorator
 from testtools.content import TracebackContent
@@ -183,9 +188,15 @@ def tags_to_new_gone(tags):
 class DiscardStream(object):
     """A filelike object which discards what is written to it."""
 
+    def fileno(self):
+        raise _UnsupportedOperation()
+
     def write(self, bytes):
         pass
 
+    def read(self, len=0):
+        return _b('')
+
 
 class _ParserState(object):
     """State for the subunit parser."""
@@ -600,8 +611,8 @@ class TestProtocolClient(testresult.TestResult):
 
     def __init__(self, stream):
         testresult.TestResult.__init__(self)
+        stream = _make_stream_binary(stream)
         self._stream = stream
-        _make_stream_binary(stream)
         self._progress_fmt = _b("progress: ")
         self._bytes_eol = _b("\n")
         self._progress_plus = _b("+")
@@ -1141,11 +1152,11 @@ class ProtocolTestCase(object):
         :param forward: A stream to pass subunit input on to. If not supplied
             subunit input is not forwarded.
         """
+        stream = _make_stream_binary(stream)
         self._stream = stream
-        _make_stream_binary(stream)
         self._passthrough = passthrough
         if forward is not None:
-            _make_stream_binary(forward)
+            forward = _make_stream_binary(forward)
         self._forward = forward
 
     def __call__(self, result=None):
@@ -1228,11 +1239,6 @@ def get_default_formatter():
         return stream
 
 
-if sys.version_info > (3, 0):
-    from io import UnsupportedOperation as _NoFilenoError
-else:
-    _NoFilenoError = AttributeError
-
 def read_test_list(path):
     """Read a list of test ids from a file on disk.
 
@@ -1247,15 +1253,37 @@ def read_test_list(path):
 
 
 def _make_stream_binary(stream):
-    """Ensure that a stream will be binary safe. See _make_binary_on_windows."""
+    """Ensure that a stream will be binary safe. See _make_binary_on_windows.
+    
+    :return: A binary version of the same stream (some streams cannot be
+        'fixed' but can be unwrapped).
+    """
     try:
         fileno = stream.fileno()
-    except _NoFilenoError:
-        return
-    _make_binary_on_windows(fileno)
+    except _UnsupportedOperation:
+        pass
+    else:
+        _make_binary_on_windows(fileno)
+    return _unwrap_text(stream)
 
 def _make_binary_on_windows(fileno):
     """Win32 mangles \r\n to \n and that breaks streams. See bug lp:505078."""
     if sys.platform == "win32":
         import msvcrt
         msvcrt.setmode(fileno, os.O_BINARY)
+
+
+def _unwrap_text(stream):
+    """Unwrap stream if it is a text stream to get the original buffer."""
+    if sys.version_info > (3, 0):
+        try:
+            # Read streams
+            if type(stream.read(0)) is str:
+                return stream.buffer
+        except _UnsupportedOperation:
+            # Cannot read from the stream: try via writes
+            try:
+                stream.write(_b(''))
+            except TypeError:
+                return stream.buffer
+    return stream
index 222359ba842b3e0869d7a8a48212ee29bf5ceaee..3d63ff56326159e2a41f07a5f450ca0268db243f 100644 (file)
@@ -88,9 +88,9 @@ xfail todo
         self.run_tests(result_filter)
         tests_included = [
             event[1] for event in result._events if event[0] == 'startTest']
-        tests_expected = map(
+        tests_expected = list(map(
             subunit.RemotedTestCase,
-            ['passed', 'error', 'skipped', 'todo'])
+            ['passed', 'error', 'skipped', 'todo']))
         self.assertEquals(tests_expected, tests_included)
 
     def test_tags_tracked_correctly(self):
@@ -98,7 +98,7 @@ xfail todo
         result = ExtendedTestResult()
         result_filter = TestResultFilter(
             result, filter_success=False, filter_predicate=tag_filter)
-        input_stream = (
+        input_stream = _b(
             "test: foo\n"
             "tags: a\n"
             "successful: foo\n"
@@ -319,7 +319,7 @@ xfail todo
         return result._events
 
     def test_default(self):
-        output = self.run_command([], (
+        output = self.run_command([], _b(
                 "test: foo\n"
                 "skip: foo\n"
                 ))
@@ -332,7 +332,7 @@ xfail todo
             events)
 
     def test_tags(self):
-        output = self.run_command(['-s', '--with-tag', 'a'], (
+        output = self.run_command(['-s', '--with-tag', 'a'], _b(
                 "tags: a\n"
                 "test: foo\n"
                 "success: foo\n"