subunit: Update to latest upstream version.
[samba.git] / lib / subunit / python / subunit / __init__.py
index b4c939756f268277e07ff92fff000eed7b69d424..6015c0e68ca050237178c42efa0b8094668091ea 100644 (file)
@@ -121,8 +121,14 @@ import re
 import subprocess
 import sys
 import unittest
+if sys.version_info > (3, 0):
+    from io import UnsupportedOperation as _UnsupportedOperation
+else:
+    _UnsupportedOperation = AttributeError
+
 
 from testtools import content, content_type, ExtendedToOriginalDecorator
+from testtools.content import TracebackContent
 from testtools.compat import _b, _u, BytesIO, StringIO
 try:
     from testtools.testresult.real import _StringException
@@ -182,9 +188,15 @@ def tags_to_new_gone(tags):
 class DiscardStream(object):
     """A filelike object which discards what is written to it."""
 
+    def fileno(self):
+        raise _UnsupportedOperation()
+
     def write(self, bytes):
         pass
 
+    def read(self, len=0):
+        return _b('')
+
 
 class _ParserState(object):
     """State for the subunit parser."""
@@ -599,8 +611,8 @@ class TestProtocolClient(testresult.TestResult):
 
     def __init__(self, stream):
         testresult.TestResult.__init__(self)
+        stream = _make_stream_binary(stream)
         self._stream = stream
-        _make_stream_binary(stream)
         self._progress_fmt = _b("progress: ")
         self._bytes_eol = _b("\n")
         self._progress_plus = _b("+")
@@ -682,10 +694,9 @@ class TestProtocolClient(testresult.TestResult):
                 raise ValueError
         if error is not None:
             self._stream.write(self._start_simple)
-            # XXX: this needs to be made much stricter, along the lines of
-            # Martin[gz]'s work in testtools. Perhaps subunit can use that?
-            for line in self._exc_info_to_unicode(error, test).splitlines():
-                self._stream.write(("%s\n" % line).encode('utf8'))
+            tb_content = TracebackContent(error, test)
+            for bytes in tb_content.iter_bytes():
+                self._stream.write(bytes)
         elif details is not None:
             self._write_details(details)
         else:
@@ -755,6 +766,15 @@ class TestProtocolClient(testresult.TestResult):
         self._stream.write(self._progress_fmt + prefix + offset +
             self._bytes_eol)
 
+    def tags(self, new_tags, gone_tags):
+        """Inform the client about tags added/removed from the stream."""
+        if not new_tags and not gone_tags:
+            return
+        tags = set([tag.encode('utf8') for tag in new_tags])
+        tags.update([_b("-") + tag.encode('utf8') for tag in gone_tags])
+        tag_line = _b("tags: ") + _b(" ").join(tags) + _b("\n")
+        self._stream.write(tag_line)
+
     def time(self, a_datetime):
         """Inform the client of the time.
 
@@ -1122,7 +1142,7 @@ class ProtocolTestCase(object):
     :seealso: TestProtocolServer (the subunit wire protocol parser).
     """
 
-    def __init__(self, stream, passthrough=None, forward=False):
+    def __init__(self, stream, passthrough=None, forward=None):
         """Create a ProtocolTestCase reading from stream.
 
         :param stream: A filelike object which a subunit stream can be read
@@ -1132,9 +1152,11 @@ class ProtocolTestCase(object):
         :param forward: A stream to pass subunit input on to. If not supplied
             subunit input is not forwarded.
         """
+        stream = _make_stream_binary(stream)
         self._stream = stream
-        _make_stream_binary(stream)
         self._passthrough = passthrough
+        if forward is not None:
+            forward = _make_stream_binary(forward)
         self._forward = forward
 
     def __call__(self, result=None):
@@ -1217,11 +1239,6 @@ def get_default_formatter():
         return stream
 
 
-if sys.version_info > (3, 0):
-    from io import UnsupportedOperation as _NoFilenoError
-else:
-    _NoFilenoError = AttributeError
-
 def read_test_list(path):
     """Read a list of test ids from a file on disk.
 
@@ -1236,15 +1253,37 @@ def read_test_list(path):
 
 
 def _make_stream_binary(stream):
-    """Ensure that a stream will be binary safe. See _make_binary_on_windows."""
+    """Ensure that a stream will be binary safe. See _make_binary_on_windows.
+    
+    :return: A binary version of the same stream (some streams cannot be
+        'fixed' but can be unwrapped).
+    """
     try:
         fileno = stream.fileno()
-    except _NoFilenoError:
-        return
-    _make_binary_on_windows(fileno)
+    except _UnsupportedOperation:
+        pass
+    else:
+        _make_binary_on_windows(fileno)
+    return _unwrap_text(stream)
 
 def _make_binary_on_windows(fileno):
     """Win32 mangles \r\n to \n and that breaks streams. See bug lp:505078."""
     if sys.platform == "win32":
         import msvcrt
         msvcrt.setmode(fileno, os.O_BINARY)
+
+
+def _unwrap_text(stream):
+    """Unwrap stream if it is a text stream to get the original buffer."""
+    if sys.version_info > (3, 0):
+        try:
+            # Read streams
+            if type(stream.read(0)) is str:
+                return stream.buffer
+        except (_UnsupportedOperation, IOError):
+            # Cannot read from the stream: try via writes
+            try:
+                stream.write(_b(''))
+            except TypeError:
+                return stream.buffer
+    return stream