Move the TEST_STARTED parser state to a state object.
authorRobert Collins <robertc@robertcollins.net>
Sat, 10 Oct 2009 19:24:25 +0000 (06:24 +1100)
committerRobert Collins <robertc@robertcollins.net>
Sat, 10 Oct 2009 19:24:25 +0000 (06:24 +1100)
python/subunit/__init__.py

index 911308e07c9f391492de00f56f20670d61195821..012555e6ad1e19b34a3ef24278ce53e8f3e002a2 100644 (file)
@@ -173,19 +173,126 @@ class DiscardStream(object):
         pass
 
 
-class _OutSideTest(object):
+class _ParserState(object):
     """State for the subunit parser."""
 
     def __init__(self, parser):
         self.parser = parser
 
+    def addError(self, offset, line):
+        """An 'error:' directive has been read."""
+        self.parser.stdOutLineReceived(line)
+
+    def addExpectedFail(self, offset, line):
+        """An 'xfail:' directive has been read."""
+        self.parser.stdOutLineReceived(line)
+
+    def addFailure(self, offset, line):
+        """A 'failure:' directive has been read."""
+        self.parser.stdOutLineReceived(line)
+
+    def addSkip(self, offset, line):
+        """A 'skip:' directive has been read."""
+        self.parser.stdOutLineReceived(line)
+
+    def addSuccess(self, offset, line):
+        """A 'success:' directive has been read."""
+        self.parser.stdOutLineReceived(line)
+
+    def startTest(self, offset, line):
+        """A test start command received."""
+        self.parser.stdOutLineReceived(line)
+
+
+class _InTest(_ParserState):
+    """State for the subunit parser after reading a test: directive."""
+
+    def addError(self, offset, line):
+        """An 'error:' directive has been read."""
+        if self.parser.current_test_description == line[offset:-1]:
+            self.parser._state = self.parser._outside_test
+            self.parser.state = TestProtocolServer.STATE_OBJECT
+            self.parser.current_test_description = None
+            self.parser.client.addError(self.parser._current_test, RemoteError(""))
+            self.parser.client.stopTest(self.parser._current_test)
+            self.parser._current_test = None
+        elif self.parser.current_test_description + " [" == line[offset:-1]:
+            self.parser.state = TestProtocolServer.READING_ERROR
+            self.parser._message = ""
+        else:
+            self.parser.stdOutLineReceived(line)
+
+    def addExpectedFail(self, offset, line):
+        """An 'xfail:' directive has been read."""
+        if self.parser.current_test_description == line[offset:-1]:
+            self.parser._state = self.parser._outside_test
+            self.parser.state = TestProtocolServer.STATE_OBJECT
+            self.parser.current_test_description = None
+            xfail = getattr(self.parser.client, 'addExpectedFailure', None)
+            if callable(xfail):
+                xfail(self.parser._current_test, RemoteError())
+            else:
+                self.parser.client.addSuccess(self.parser._current_test)
+            self.parser.client.stopTest(self.parser._current_test)
+        elif self.parser.current_test_description + " [" == line[offset:-1]:
+            self.parser.state = TestProtocolServer.READING_XFAIL
+            self.parser._message = ""
+        else:
+            self.parser.stdOutLineReceived(line)
+
+    def addFailure(self, offset, line):
+        """A 'failure:' directive has been read."""
+        if self.parser.current_test_description == line[offset:-1]:
+            self.parser._state = self.parser._outside_test
+            self.parser.state = TestProtocolServer.STATE_OBJECT
+            self.parser.current_test_description = None
+            self.parser.client.addFailure(self.parser._current_test, RemoteError())
+            self.parser.client.stopTest(self.parser._current_test)
+        elif self.parser.current_test_description + " [" == line[offset:-1]:
+            self.parser.state = TestProtocolServer.READING_FAILURE
+            self.parser._message = ""
+        else:
+            self.parser.stdOutLineReceived(line)
+
+    def addSkip(self, offset, line):
+        """A 'skip:' directive has been read."""
+        if self.parser.current_test_description == line[offset:-1]:
+            self.parser._state = self.parser._outside_test
+            self.parser.state = TestProtocolServer.STATE_OBJECT
+            self.parser.current_test_description = None
+            self.parser._skip_or_error()
+            self.parser.client.stopTest(self.parser._current_test)
+        elif self.parser.current_test_description + " [" == line[offset:-1]:
+            self.parser.state = TestProtocolServer.READING_SKIP
+            self.parser._message = ""
+        else:
+            self.parser.stdOutLineReceived(line)
+
+    def addSuccess(self, offset, line):
+        """A 'success:' directive has been read."""
+        if self.parser.current_test_description == line[offset:-1]:
+            self.parser._succeedTest()
+        elif self.parser.current_test_description + " [" == line[offset:-1]:
+            self.parser.state = TestProtocolServer.READING_SUCCESS
+            self.parser._message = ""
+        else:
+            self.parser.stdOutLineReceived(line)
+
+    def lostConnection(self):
+        """Connection lost."""
+        self.parser._lostConnectionInTest('')
+
+
+class _OutSideTest(_ParserState):
+    """State for the subunit parser outside of a test context."""
+
     def lostConnection(self):
         """Connection lost."""
 
     def startTest(self, offset, line):
         """A test start command received."""
-        self.parser.state = TestProtocolServer.TEST_STARTED
-        self.parser._state = None
+        self.parser._state = self.parser._in_test
+        self.parser.state = TestProtocolServer.STATE_OBJECT
         self.parser._current_test = RemotedTestCase(line[offset:-1])
         self.parser.current_test_description = line[offset:-1]
         self.parser.client.startTest(self.parser._current_test)
@@ -198,7 +305,7 @@ class TestProtocolServer(object):
     """
 
     STATE_OBJECT = 0
-    TEST_STARTED = 1
+    STATE_OBJECTS = [0]
     READING_FAILURE = 2
     READING_ERROR = 3
     READING_SKIP = 4
@@ -218,72 +325,34 @@ class TestProtocolServer(object):
         if stream is None:
             stream = sys.stdout
         self._stream = stream
+        # state objects we can switch too
+        self._in_test = _InTest(self)
         self._outside_test = _OutSideTest(self)
+        # start with outside test.
         self._state = self._outside_test
         self.state = TestProtocolServer.STATE_OBJECT
 
     def _addError(self, offset, line):
-        if (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description == line[offset:-1]):
-            self._state = self._outside_test
-            self.state = TestProtocolServer.STATE_OBJECT
-            self.current_test_description = None
-            self.client.addError(self._current_test, RemoteError(""))
-            self.client.stopTest(self._current_test)
-            self._current_test = None
-        elif (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description + " [" == line[offset:-1]):
-            self.state = TestProtocolServer.READING_ERROR
-            self._message = ""
+        if self.state in TestProtocolServer.STATE_OBJECTS:
+            self._state.addError(offset, line)
         else:
             self.stdOutLineReceived(line)
 
     def _addExpectedFail(self, offset, line):
-        if (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description == line[offset:-1]):
-            self._state = self._outside_test
-            self.state = TestProtocolServer.STATE_OBJECT
-            self.current_test_description = None
-            xfail = getattr(self.client, 'addExpectedFailure', None)
-            if callable(xfail):
-                xfail(self._current_test, RemoteError())
-            else:
-                self.client.addSuccess(self._current_test)
-            self.client.stopTest(self._current_test)
-        elif (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description + " [" == line[offset:-1]):
-            self.state = TestProtocolServer.READING_XFAIL
-            self._message = ""
+        if self.state in TestProtocolServer.STATE_OBJECTS:
+            self._state.addExpectedFail(offset, line)
         else:
             self.stdOutLineReceived(line)
 
     def _addFailure(self, offset, line):
-        if (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description == line[offset:-1]):
-            self._state = self._outside_test
-            self.state = TestProtocolServer.STATE_OBJECT
-            self.current_test_description = None
-            self.client.addFailure(self._current_test, RemoteError())
-            self.client.stopTest(self._current_test)
-        elif (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description + " [" == line[offset:-1]):
-            self.state = TestProtocolServer.READING_FAILURE
-            self._message = ""
+        if self.state in TestProtocolServer.STATE_OBJECTS:
+            self._state.addFailure(offset, line)
         else:
             self.stdOutLineReceived(line)
 
     def _addSkip(self, offset, line):
-        if (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description == line[offset:-1]):
-            self._state = self._outside_test
-            self.state = TestProtocolServer.STATE_OBJECT
-            self.current_test_description = None
-            self._skip_or_error()
-            self.client.stopTest(self._current_test)
-        elif (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description + " [" == line[offset:-1]):
-            self.state = TestProtocolServer.READING_SKIP
-            self._message = ""
+        if self.state in TestProtocolServer.STATE_OBJECTS:
+            self._state.addSkip(offset, line)
         else:
             self.stdOutLineReceived(line)
 
@@ -298,13 +367,8 @@ class TestProtocolServer(object):
             addSkip(self._current_test, message)
 
     def _addSuccess(self, offset, line):
-        if (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description == line[offset:-1]):
-            self._succeedTest()
-        elif (self.state == TestProtocolServer.TEST_STARTED and
-            self.current_test_description + " [" == line[offset:-1]):
-            self.state = TestProtocolServer.READING_SUCCESS
-            self._message = ""
+        if self.state in TestProtocolServer.STATE_OBJECTS:
+            self._state.addSuccess(offset, line)
         else:
             self.stdOutLineReceived(line)
 
@@ -435,12 +499,10 @@ class TestProtocolServer(object):
 
     def lostConnection(self):
         """The input connection has finished."""
-        if self.state == TestProtocolServer.STATE_OBJECT:
+        if self.state in TestProtocolServer.STATE_OBJECTS:
             self._state.lostConnection()
             return
-        if self.state == TestProtocolServer.TEST_STARTED:
-            self._lostConnectionInTest('')
-        elif self.state == TestProtocolServer.READING_ERROR:
+        if self.state == TestProtocolServer.READING_ERROR:
             self._lostConnectionInTest('error report of ')
         elif self.state == TestProtocolServer.READING_FAILURE:
             self._lostConnectionInTest('failure report of ')
@@ -460,7 +522,7 @@ class TestProtocolServer(object):
 
     def _startTest(self, offset, line):
         """Internal call to change state machine. Override startTest()."""
-        if self.state == TestProtocolServer.STATE_OBJECT:
+        if self.state in TestProtocolServer.STATE_OBJECTS:
             self._state.startTest(offset, line)
         else:
             self.stdOutLineReceived(line)