Improve server protocol error handling; fix flush-pkt handling.
authorDave Borowitz <dborowitz@google.com>
Tue, 26 Jan 2010 17:51:29 +0000 (09:51 -0800)
committerDave Borowitz <dborowitz@google.com>
Tue, 9 Feb 2010 17:45:03 +0000 (09:45 -0800)
Change-Id: Ib3631cd3167071f555113f62d80bc63bf02acdad

dulwich/server.py
dulwich/tests/test_server.py

index 6e370f9e0fbaeccf77d411c51745f1fc440729f8..eff09ac6caed2917b419a43250d533c14efbe4ff 100644 (file)
 # MA  02110-1301, USA.
 
 
-"""Git smart network protocol server implementation."""
+"""Git smart network protocol server implementation.
+
+For more detailed implementation on the network protocol, see the
+Documentation/technical directory in the cgit distribution, and in particular:
+    Documentation/technical/protocol-capabilities.txt
+    Documentation/technical/pack-protocol.txt
+"""
 
 
 import collections
@@ -137,47 +143,13 @@ class UploadPackHandler(Handler):
                                    set_client_capabilities)
 
     def handle(self):
-        def determine_wants(heads):
-            keys = heads.keys()
-            values = set(heads.itervalues())
-            if keys:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % ( heads[keys[0]], keys[0], self.capabilities()))
-                for k in keys[1:]:
-                    self.proto.write_pkt_line("%s %s\n" % (heads[k], k))
-
-            # i'm done..
-            self.proto.write("0000")
-
-            # Now client will either send "0000", meaning that it doesnt want to pull.
-            # or it will start sending want want want commands
-            want = self.proto.read_pkt_line()
-            if want == None:
-                return []
-
-            want, self.client_capabilities = extract_want_line_capabilities(want)
-            graph_walker.set_ack_type(ack_type(self.client_capabilities))
-
-            want_revs = []
-            while want and want[:4] == 'want':
-                sha = want[5:45]
-                try:
-                    hex_to_sha(sha)
-                except (TypeError, AssertionError), e:
-                    raise GitProtocolError(e)
-
-                if sha not in values:
-                    raise GitProtocolError(
-                        'Client wants invalid object %s' % sha)
-                want_revs.append(sha)
-                want = self.proto.read_pkt_line()
-            graph_walker.set_wants(want_revs)
-            return want_revs
 
         progress = lambda x: self.proto.write_sideband(2, x)
         write = lambda x: self.proto.write_sideband(1, x)
 
-        graph_walker = ProtocolGraphWalker(self.backend.object_store, self.proto)
-        objects_iter = self.backend.fetch_objects(determine_wants, graph_walker, progress)
+        graph_walker = ProtocolGraphWalker(self)
+        objects_iter = self.backend.fetch_objects(
+          graph_walker.determine_wants, graph_walker, progress)
 
         # Do they want any objects?
         if len(objects_iter) == 0:
@@ -205,15 +177,63 @@ class ProtocolGraphWalker(object):
     call to set_ack_level() is required to set up the implementation, before any
     calls to next() or ack() are made.
     """
-    def __init__(self, object_store, proto):
-        self.store = object_store
-        self.proto = proto
+    def __init__(self, handler):
+        self.handler = handler
+        self.store = handler.backend.object_store
+        self.proto = handler.proto
         self._wants = []
         self._cached = False
         self._cache = []
         self._cache_index = 0
         self._impl = None
 
+    def determine_wants(self, heads):
+        """Determine the wants for a set of heads.
+
+        The given heads are advertised to the client, who then specifies which
+        refs he wants using 'want' lines. This portion of the protocol is the
+        same regardless of ack type, and in fact is used to set the ack type of
+        the ProtocolGraphWalker.
+
+        :param heads: a dict of refname->SHA1 to advertise
+        :return: a list of SHA1s requested by the client
+        """
+        if not heads:
+            raise GitProtocolError('No heads found')
+        values = set(heads.itervalues())
+        for i, (ref, sha) in enumerate(heads.iteritems()):
+            line = "%s %s" % (sha, ref)
+            if not i:
+                line = "%s\x00%s" % (line, self.handler.capabilities())
+            self.proto.write_pkt_line("%s\n" % line)
+            # TODO: include peeled value of any tags
+
+        # i'm done..
+        self.proto.write_pkt_line(None)
+
+        # Now client will sending want want want commands
+        want = self.proto.read_pkt_line()
+        if not want:
+            return []
+        line, caps = extract_want_line_capabilities(want)
+        self.handler.client_capabilities = caps
+        self.set_ack_type(ack_type(caps))
+        command, sha = self._split_proto_line(line)
+
+        want_revs = []
+        while command != None:
+            if command != 'want':
+                raise GitProtocolError(
+                    'Protocol got unexpected command %s' % command)
+            if sha not in values:
+                raise GitProtocolError(
+                    'Client wants invalid object %s' % sha)
+            want_revs.append(sha)
+            command, sha = self.read_proto_line()
+
+        self.set_wants(want_revs)
+        return want_revs
+
     def ack(self, have_ref):
         return self._impl.ack(have_ref)
 
@@ -223,16 +243,31 @@ class ProtocolGraphWalker(object):
 
     def next(self):
         if not self._cached:
+            if not self._impl:
+                return None
             return self._impl.next()
         self._cache_index += 1
         if self._cache_index > len(self._cache):
             return None
         return self._cache[self._cache_index]
 
+    def _split_proto_line(self, line):
+        fields = line.rstrip('\n').split(' ', 1)
+        if len(fields) == 1 and fields[0] == 'done':
+            return ('done', None)
+        elif len(fields) == 2 and fields[0] in ('want', 'have'):
+            try:
+                hex_to_sha(fields[1])
+                return tuple(fields)
+            except (TypeError, AssertionError), e:
+                raise GitProtocolError(e)
+        raise GitProtocolError('Received invalid line from client:\n%s' % line)
+
     def read_proto_line(self):
         """Read a line from the wire.
 
         :return: a tuple having one of the following forms:
+            ('want', obj_id)
             ('have', obj_id)
             ('done', None)
             (None, None)  (for a flush-pkt)
@@ -240,16 +275,7 @@ class ProtocolGraphWalker(object):
         line = self.proto.read_pkt_line()
         if not line:
             return (None, None)
-        fields = line.rstrip('\n').split(' ', 1)
-        if len(fields) == 1 and fields[0] == 'done':
-            return ('done', None)
-        if len(fields) == 2 and fields[0] == 'have':
-            try:
-                hex_to_sha(fields[1])
-                return fields
-            except (TypeError, AssertionError), e:
-                raise GitProtocolError(e)
-        raise GitProtocolError('Received invalid line from client:\n%s' % line)
+        return self._split_proto_line(line)
 
     def send_ack(self, sha, ack_type=''):
         if ack_type:
@@ -351,23 +377,26 @@ class MultiAckGraphWalkerImpl(object):
         # else we blind ack within next
 
     def next(self):
-        command, sha = self.walker.read_proto_line()
-        if command is None:
-            self.walker.send_nak()
-            return None
-        elif command == 'done':
-            # don't nak unless no common commits were found, even if not
-            # everything is satisfied
-            if self._common:
-                self.walker.send_ack(self._common[-1])
-            else:
+        while True:
+            command, sha = self.walker.read_proto_line()
+            if command is None:
                 self.walker.send_nak()
-            return None
-        elif command == 'have':
-            if self._found_base:
-                # blind ack
-                self.walker.send_ack(sha, 'continue')
-            return sha
+                # in multi-ack mode, a flush-pkt indicates the client wants to
+                # flush but more have lines are still coming
+                continue
+            elif command == 'done':
+                # don't nak unless no common commits were found, even if not
+                # everything is satisfied
+                if self._common:
+                    self.walker.send_ack(self._common[-1])
+                else:
+                    self.walker.send_nak()
+                return None
+            elif command == 'have':
+                if self._found_base:
+                    # blind ack
+                    self.walker.send_ack(sha, 'continue')
+                return sha
 
 
 class ReceivePackHandler(Handler):
index 13c886ee0bf1701c8064eb272712dd5648db4e30..71df31910f546c3b2bc3e4eba72d2caa9c44a2f7 100644 (file)
@@ -113,6 +113,20 @@ class TestCommit(object):
         return '%s(%s)' % (self.__class__.__name__, self._sha)
 
 
+class TestBackend(object):
+    def __init__(self, objects):
+        self.object_store = objects
+
+
+class TestHandler(object):
+    def __init__(self, objects, proto):
+        self.backend = TestBackend(objects)
+        self.proto = proto
+
+    def capabilities(self):
+        return 'multi_ack'
+
+
 class ProtocolGraphWalkerTestCase(TestCase):
     def setUp(self):
         # Create the following commit tree:
@@ -126,7 +140,8 @@ class ProtocolGraphWalkerTestCase(TestCase):
             FOUR: TestCommit(FOUR, [TWO], 444),
             FIVE: TestCommit(FIVE, [THREE], 555),
             }
-        self._walker = ProtocolGraphWalker(self._objects, None)
+        self._walker = ProtocolGraphWalker(
+            TestHandler(self._objects, TestProto()))
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -153,6 +168,45 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertFalse(self._walker.all_wants_satisfied([THREE]))
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
 
+    def test_read_proto_line(self):
+        self._walker.proto.set_output([
+            'want %s' % ONE,
+            'want %s' % TWO,
+            'have %s' % THREE,
+            'foo %s' % FOUR,
+            'bar',
+            'done',
+            ])
+        self.assertEquals(('want', ONE), self._walker.read_proto_line())
+        self.assertEquals(('want', TWO), self._walker.read_proto_line())
+        self.assertEquals(('have', THREE), self._walker.read_proto_line())
+        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
+        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
+        self.assertEquals(('done', None), self._walker.read_proto_line())
+        self.assertEquals((None, None), self._walker.read_proto_line())
+
+    def test_determine_wants(self):
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
+
+        self._walker.proto.set_output([
+            'want %s multi_ack' % ONE,
+            'want %s' % TWO,
+            ])
+        heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
+        self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
+
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        self._walker.proto.set_output([])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        self._walker.proto.set_output(['want %s multi_ack' % ONE, 'foo'])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
     # TODO: test commit time cutoff
 
 
@@ -307,13 +361,19 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertAck(ONE)
 
     def test_multi_ack_flush(self):
-        # same as ack test but ends with a flush-pkt instead of done
-        self._walker.lines[-1] = (None, None)
-
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
         self.assertNextEquals(TWO)
         self.assertNoAck()
 
         self.assertNextEquals(ONE)
+        self.assertNak() # nak the flush-pkt
+
         self._walker.done = True
         self._impl.ack(ONE)
         self.assertAck(ONE, 'continue')
@@ -323,7 +383,7 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertAck(THREE, 'continue')
 
         self.assertNextEquals(None)
-        self.assertNak()
+        self.assertAck(THREE)
 
     def test_multi_ack_nak(self):
         self.assertNextEquals(TWO)
@@ -337,19 +397,3 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
 
         self.assertNextEquals(None)
         self.assertNak()
-
-    def test_multi_ack_nak_flush(self):
-        # same as nak test but ends with a flush-pkt instead of done
-        self._walker.lines[-1] = (None, None)
-
-        self.assertNextEquals(TWO)
-        self.assertNoAck()
-
-        self.assertNextEquals(ONE)
-        self.assertNoAck()
-
-        self.assertNextEquals(THREE)
-        self.assertNoAck()
-
-        self.assertNextEquals(None)
-        self.assertNak()