Refactor server capability code into base Handler.
authorDave Borowitz <dborowitz@google.com>
Fri, 19 Feb 2010 21:33:38 +0000 (13:33 -0800)
committerDave Borowitz <dborowitz@google.com>
Thu, 4 Mar 2010 17:50:05 +0000 (09:50 -0800)
UploadPackHandler and ReceivePackHandler now both handle client
capabilities using a consistent interface, the set_client_capabilites
and has_capability functions. Both now error as soon as an unknown
capability is requested by the client.

Also renames the following methods:
  capabilities -> capability_line
  default_capabilities -> capabilities
This is because capability_line is the less general of the two
methods, as it is only useful when advertising capabilities to the
client.

Changed capabilities tests to use the base class and test new
functionality.

Change-Id: If7d3feeac27834119d6d4e4021569401e5444d51

dulwich/server.py
dulwich/tests/test_server.py

index 2e19838d402790a49b86682b7a1525ac075b2b27..c31e952cf24e22f631299ef93089f78c6c6272d1 100644 (file)
@@ -152,9 +152,27 @@ class Handler(object):
     def __init__(self, backend, read, write):
         self.backend = backend
         self.proto = Protocol(read, write)
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def set_client_capabilities(self, caps):
+        my_caps = self.capabilities()
+        for cap in caps:
+            if cap not in my_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        self._client_capabilities = caps
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 class UploadPackHandler(Handler):
@@ -163,29 +181,14 @@ class UploadPackHandler(Handler):
     def __init__(self, backend, read, write,
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
                 "ofs-delta")
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
-
-    def get_client_capabilities(self):
-        return self._client_capabilities
-
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
-
     def handle(self):
 
         progress = lambda x: self.proto.write_sideband(2, x)
@@ -251,7 +254,7 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
                 # TODO: include peeled value of any tags
 
@@ -266,7 +269,7 @@ class ProtocolGraphWalker(object):
         if not want:
             return []
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
 
@@ -509,7 +512,7 @@ class ReceivePackHandler(Handler):
         self._stateless_rpc = stateless_rpc
         self._advertise_refs = advertise_refs
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
 
     def handle(self):
@@ -517,12 +520,14 @@ class ReceivePackHandler(Handler):
 
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                    "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                       self.capability_line()))
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capability_line())
 
             self.proto.write("0000")
             if self.advertise_refs:
@@ -535,7 +540,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
             return
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
@@ -547,7 +553,7 @@ class ReceivePackHandler(Handler):
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)
index 56a6dc23a142a2f795c911434a359e3f231c0197..90947f4bb9e3a3ba1ce341009f19fa646e8649a2 100644 (file)
@@ -28,6 +28,7 @@ from dulwich.errors import (
     )
 from dulwich.server import (
     UploadPackHandler,
+    Handler,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
@@ -75,30 +76,36 @@ class TestProto(object):
             return None
 
 
-class UploadPackHandlerTestCase(TestCase):
+class HandlerTestCase(TestCase):
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._handler = Handler(None, None, None)
+        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
 
-    def test_set_client_capabilities(self):
+    def assertSucceeds(self, func, *args, **kwargs):
         try:
-            self._handler.set_client_capabilities([])
+            func(*args, **kwargs)
         except GitProtocolError:
             self.fail()
 
-        try:
-            self._handler.set_client_capabilities([
-                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
-        except GitProtocolError:
-            self.fail()
+    def test_capability_line(self):
+        self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
 
-    def test_set_client_capabilities_error(self):
-        self.assertRaises(GitProtocolError,
-                          self._handler.set_client_capabilities,
-                          ['weird_ack_level', 'ofs-delta'])
-        try:
-            self._handler.set_client_capabilities(['include-tag'])
-        except GitProtocolError:
-            self.fail()
+    def test_set_client_capabilities(self):
+        set_caps = self._handler.set_client_capabilities
+        self.assertSucceeds(set_caps, [])
+        self.assertSucceeds(set_caps, ['cap2'])
+        self.assertSucceeds(set_caps, ['cap1', 'cap2'])
+        # different order
+        self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
+        self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap1'])
+
+    def test_has_capability(self):
+        self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
+        caps = self._handler.capabilities()
+        self._handler.set_client_capabilities(caps)
+        for cap in caps:
+            self.assertTrue(self._handler.has_capability(cap))
+        self.assertFalse(self._handler.has_capability('capxxx'))
 
 
 class TestCommit(object):
@@ -119,7 +126,7 @@ class TestBackend(object):
         self.object_store = objects
 
 
-class TestHandler(object):
+class TestUploadPackHandler(Handler):
     def __init__(self, objects, proto):
         self.backend = TestBackend(objects)
         self.proto = proto
@@ -127,7 +134,7 @@ class TestHandler(object):
         self.advertise_refs = False
 
     def capabilities(self):
-        return 'multi_ack'
+        return ('multi_ack',)
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
@@ -144,7 +151,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             FIVE: TestCommit(FIVE, [THREE], 555),
             }
         self._walker = ProtocolGraphWalker(
-            TestHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()))
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))