server: Explicitly specify allowed protocol commands.
authorDave Borowitz <dborowitz@google.com>
Sun, 8 Aug 2010 19:32:08 +0000 (21:32 +0200)
committerJelmer Vernooij <jelmer@samba.org>
Sun, 8 Aug 2010 19:32:08 +0000 (21:32 +0200)
This means callers do not have to check the return value of
read_proto_line themselves, avoiding lots of duplicate code like
"if command not in (...): raise".

NEWS
dulwich/errors.py
dulwich/server.py
dulwich/tests/test_server.py

diff --git a/NEWS b/NEWS
index 9035100..f44d155 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -60,6 +60,9 @@
 
   * Clean up docstrings in dulwich.protocol. (Dave Borowitz)
 
+  * Explicitly specify allowed protocol commands to
+    ProtocolGraphWalker.read_proto_line.  (Dave Borowitz)
+
 
 0.6.1  2010-07-22
 
index 80f54b3..8fec614 100644 (file)
@@ -137,6 +137,17 @@ class HangupException(GitProtocolError):
             "The remote server unexpectedly closed the connection.")
 
 
+class UnexpectedCommandError(GitProtocolError):
+    """Unexpected command received in a proto line."""
+
+    def __init__(self, command):
+        if command is None:
+            command = 'flush-pkt'
+        else:
+            command = 'command %s' % command
+        GitProtocolError.__init__(self, 'Protocol got unexpected %s' % command)
+
+
 class FileFormatException(Exception):
     """Base class for exceptions relating to reading git file formats."""
 
index 28331e6..7a97e64 100644 (file)
@@ -36,6 +36,7 @@ from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     GitProtocolError,
+    UnexpectedCommandError,
     ObjectFormatException,
     )
 from dulwich import log_utils
@@ -276,31 +277,39 @@ class UploadPackHandler(Handler):
         self.proto.write("0000")
 
 
-def _split_proto_line(line):
+def _split_proto_line(line, allowed):
     """Split a line read from the wire.
 
+    :param line: The line read from the wire.
+    :param allowed: An iterable of command names that should be allowed.
+        Command names not listed below as possible return values will be
+        ignored.  If None, any commands from the possible return values are
+        allowed.
     :return: a tuple having one of the following forms:
         ('want', obj_id)
         ('have', obj_id)
         ('done', None)
         (None, None)  (for a flush-pkt)
 
-    :raise GitProtocolError: if the line cannot be parsed into one of the
-        possible return values.
+    :raise UnexpectedCommandError: if the line cannot be parsed into one of the
+        allowed return values.
     """
     if not line:
         fields = [None]
     else:
         fields = line.rstrip('\n').split(' ', 1)
-    if len(fields) == 1 and fields[0] in ('done', None):
-        return (fields[0], None)
-    elif len(fields) == 2 and fields[0] in ('want', 'have'):
-        try:
+    command = fields[0]
+    if allowed is not None and command not in allowed:
+        raise UnexpectedCommandError(command)
+    try:
+        if len(fields) == 1 and command in ('done', None):
+            return (command, None)
+        elif len(fields) == 2 and command in ('want', 'have'):
             hex_to_sha(fields[1])
             return tuple(fields)
-        except (TypeError, AssertionError), e:
-            raise GitProtocolError(e)
-    raise GitProtocolError('Received invalid line from client:\n%s' % line)
+    except (TypeError, AssertionError), e:
+        raise GitProtocolError(e)
+    raise GitProtocolError('Received invalid line from client: %s' % line)
 
 
 class ProtocolGraphWalker(object):
@@ -367,18 +376,16 @@ class ProtocolGraphWalker(object):
         line, caps = extract_want_line_capabilities(want)
         self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
-        command, sha = _split_proto_line(line)
+        allowed = ('want', None)
+        command, sha = _split_proto_line(line, allowed)
 
         want_revs = []
         while command != None:
-            if command != 'want':
-                raise GitProtocolError(
-                  'Protocol got unexpected command %s' % command)
             if sha not in values:
                 raise GitProtocolError(
                   'Client wants invalid object %s' % sha)
             want_revs.append(sha)
-            command, sha = self.read_proto_line()
+            command, sha = self.read_proto_line(allowed)
 
         self.set_wants(want_revs)
         return want_revs
@@ -400,13 +407,14 @@ class ProtocolGraphWalker(object):
             return None
         return self._cache[self._cache_index]
 
-    def read_proto_line(self):
-        """Read and split a line from the wire.
+    def read_proto_line(self, allowed):
+        """Read a line from the wire.
 
+        :param allowed: An iterable of command names that should be allowed.
         :return: A tuple of (command, value); see _split_proto_line.
         :raise GitProtocolError: If an error occurred reading the line.
         """
-        return _split_proto_line(self.proto.read_pkt_line())
+        return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
     def send_ack(self, sha, ack_type=''):
         if ack_type:
@@ -470,6 +478,9 @@ class ProtocolGraphWalker(object):
         self._impl = impl_classes[ack_type](self)
 
 
+_GRAPH_WALKER_COMMANDS = ('have', 'done', None)
+
+
 class SingleAckGraphWalkerImpl(object):
     """Graph walker implementation that speaks the single-ack protocol."""
 
@@ -483,7 +494,7 @@ class SingleAckGraphWalkerImpl(object):
             self._sent_ack = True
 
     def next(self):
-        command, sha = self.walker.read_proto_line()
+        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
         if command in (None, 'done'):
             if not self._sent_ack:
                 self.walker.send_nak()
@@ -510,7 +521,7 @@ class MultiAckGraphWalkerImpl(object):
 
     def next(self):
         while True:
-            command, sha = self.walker.read_proto_line()
+            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
                 self.walker.send_nak()
                 # in multi-ack mode, a flush-pkt indicates the client wants to
@@ -550,7 +561,7 @@ class MultiAckDetailedGraphWalkerImpl(object):
 
     def next(self):
         while True:
-            command, sha = self.walker.read_proto_line()
+            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
                 self.walker.send_nak()
                 if self.walker.stateless_rpc:
index e66ab74..7fafac6 100644 (file)
@@ -21,6 +21,7 @@
 
 from dulwich.errors import (
     GitProtocolError,
+    UnexpectedCommandError,
     )
 from dulwich.server import (
     Backend,
@@ -272,15 +273,20 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
 
     def test_split_proto_line(self):
-        self.assertEquals(('want', ONE), _split_proto_line('want %s\n' % ONE))
-        self.assertEquals(('want', TWO), _split_proto_line('want %s\n' % TWO))
-        self.assertEquals(('have', THREE),
-                          _split_proto_line('have %s\n' % THREE))
-        self.assertRaises(GitProtocolError,
-                          _split_proto_line, 'foo %s\n' % FOUR)
-        self.assertRaises(GitProtocolError, _split_proto_line, 'bar')
-        self.assertEquals(('done', None), _split_proto_line('done\n'))
-        self.assertEquals((None, None), _split_proto_line(''))
+        allowed = ('want', 'done', None)
+        self.assertEquals(('want', ONE),
+                          _split_proto_line('want %s\n' % ONE, allowed))
+        self.assertEquals(('want', TWO),
+                          _split_proto_line('want %s\n' % TWO, allowed))
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'want xxxx\n', allowed)
+        self.assertRaises(UnexpectedCommandError, _split_proto_line,
+                          'have %s\n' % THREE, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'foo %s\n' % FOUR, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line, 'bar', allowed)
+        self.assertEquals(('done', None), _split_proto_line('done\n', allowed))
+        self.assertEquals((None, None), _split_proto_line('', allowed))
 
     def test_determine_wants(self):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
@@ -346,8 +352,11 @@ class TestProtocolGraphWalker(object):
         self.stateless_rpc = False
         self.advertise_refs = False
 
-    def read_proto_line(self):
-        return self.lines.pop(0)
+    def read_proto_line(self, allowed):
+        command, sha = self.lines.pop(0)
+        if allowed is not None:
+            assert command in allowed
+        return command, sha
 
     def send_ack(self, sha, ack_type=''):
         self.acks.append((sha, ack_type))