server: Explicitly specify allowed protocol commands.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_server.py
index bc7d2386dd318c155271880e2fe5415c79c6d010..7fafac63319c57c70eecea117c43c8772350a6aa 100644 (file)
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
-
 """Tests for the smart protocol server."""
 
 
-from unittest import TestCase
-
 from dulwich.errors import (
     GitProtocolError,
+    UnexpectedCommandError,
     )
 from dulwich.server import (
     Backend,
@@ -32,10 +30,13 @@ from dulwich.server import (
     Handler,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    _split_proto_line,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     UploadPackHandler,
     )
+from dulwich.tests import TestCase
+
 
 
 ONE = '1' * 40
@@ -45,6 +46,7 @@ FOUR = '4' * 40
 FIVE = '5' * 40
 SIX = '6' * 40
 
+
 class TestProto(object):
 
     def __init__(self):
@@ -76,12 +78,25 @@ class TestProto(object):
             return None
 
 
+class TestGenericHandler(Handler):
+
+    def __init__(self):
+        Handler.__init__(self, Backend(), None)
+
+    @classmethod
+    def capabilities(cls):
+        return ('cap1', 'cap2', 'cap3')
+
+    @classmethod
+    def required_capabilities(cls):
+        return ('cap2',)
+
+
 class HandlerTestCase(TestCase):
 
     def setUp(self):
-        self._handler = Handler(Backend(), None, None)
-        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
-        self._handler.required_capabilities = lambda: ('cap2',)
+        super(HandlerTestCase, self).setUp()
+        self._handler = TestGenericHandler()
 
     def assertSucceeds(self, func, *args, **kwargs):
         try:
@@ -122,6 +137,7 @@ class HandlerTestCase(TestCase):
 class UploadPackHandlerTestCase(TestCase):
 
     def setUp(self):
+        super(UploadPackHandlerTestCase, self).setUp()
         self._backend = DictBackend({"/": BackendRepo()})
         self._handler = UploadPackHandler(self._backend,
                 ["/", "host=lolcathost"], None, None)
@@ -175,11 +191,9 @@ class TestCommit(object):
 
     def __init__(self, sha, parents, commit_time):
         self.id = sha
-        self._parents = parents
+        self.parents = parents
         self.commit_time = commit_time
-
-    def get_parents(self):
-        return self._parents
+        self.type_name = "commit"
 
     def __repr__(self):
         return '%s(%s)' % (self.__class__.__name__, self._sha)
@@ -208,24 +222,26 @@ class TestUploadPackHandler(Handler):
         self.stateless_rpc = False
         self.advertise_refs = False
 
-    def capabilities(self):
+    @classmethod
+    def capabilities(cls):
         return ('multi_ack',)
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
 
     def setUp(self):
+        super(ProtocolGraphWalkerTestCase, self).setUp()
         # Create the following commit tree:
         #   3---5
         #  /
         # 1---2---4
         self._objects = {
-            ONE: TestCommit(ONE, [], 111),
-            TWO: TestCommit(TWO, [ONE], 222),
-            THREE: TestCommit(THREE, [ONE], 333),
-            FOUR: TestCommit(FOUR, [TWO], 444),
-            FIVE: TestCommit(FIVE, [THREE], 555),
-            }
+          ONE: TestCommit(ONE, [], 111),
+          TWO: TestCommit(TWO, [ONE], 222),
+          THREE: TestCommit(THREE, [ONE], 333),
+          FOUR: TestCommit(FOUR, [TWO], 444),
+          FIVE: TestCommit(FIVE, [THREE], 555),
+          }
 
         self._walker = ProtocolGraphWalker(
             TestUploadPackHandler(self._objects, TestProto()),
@@ -256,30 +272,29 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertFalse(self._walker.all_wants_satisfied([THREE]))
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
 
-    def test_read_proto_line(self):
-        self._walker.proto.set_output([
-            'want %s' % ONE,
-            'want %s' % TWO,
-            'have %s' % THREE,
-            'foo %s' % FOUR,
-            'bar',
-            'done',
-            ])
-        self.assertEquals(('want', ONE), self._walker.read_proto_line())
-        self.assertEquals(('want', TWO), self._walker.read_proto_line())
-        self.assertEquals(('have', THREE), self._walker.read_proto_line())
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertEquals(('done', None), self._walker.read_proto_line())
-        self.assertEquals((None, None), self._walker.read_proto_line())
+    def test_split_proto_line(self):
+        allowed = ('want', 'done', None)
+        self.assertEquals(('want', ONE),
+                          _split_proto_line('want %s\n' % ONE, allowed))
+        self.assertEquals(('want', TWO),
+                          _split_proto_line('want %s\n' % TWO, allowed))
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'want xxxx\n', allowed)
+        self.assertRaises(UnexpectedCommandError, _split_proto_line,
+                          'have %s\n' % THREE, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'foo %s\n' % FOUR, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line, 'bar', allowed)
+        self.assertEquals(('done', None), _split_proto_line('done\n', allowed))
+        self.assertEquals((None, None), _split_proto_line('', allowed))
 
     def test_determine_wants(self):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
 
         self._walker.proto.set_output([
-            'want %s multi_ack' % ONE,
-            'want %s' % TWO,
-            ])
+          'want %s multi_ack' % ONE,
+          'want %s' % TWO,
+          ])
         heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
         self._walker.get_peeled = heads.get
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
@@ -314,11 +329,11 @@ class ProtocolGraphWalkerTestCase(TestCase):
             lines.append(line.rstrip())
 
         self.assertEquals([
-            '%s ref4' % FOUR,
-            '%s ref5' % FIVE,
-            '%s tag6^{}' % FIVE,
-            '%s tag6' % SIX,
-            ], sorted(lines))
+          '%s ref4' % FOUR,
+          '%s ref5' % FIVE,
+          '%s tag6^{}' % FIVE,
+          '%s tag6' % SIX,
+          ], sorted(lines))
 
         # ensure peeled tag was advertised immediately following tag
         for i, line in enumerate(lines):
@@ -337,8 +352,11 @@ class TestProtocolGraphWalker(object):
         self.stateless_rpc = False
         self.advertise_refs = False
 
-    def read_proto_line(self):
-        return self.lines.pop(0)
+    def read_proto_line(self, allowed):
+        command, sha = self.lines.pop(0)
+        if allowed is not None:
+            assert command in allowed
+        return command, sha
 
     def send_ack(self, sha, ack_type=''):
         self.acks.append((sha, ack_type))
@@ -359,13 +377,14 @@ class AckGraphWalkerImplTestCase(TestCase):
     """Base setup and asserts for AckGraphWalker tests."""
 
     def setUp(self):
+        super(AckGraphWalkerImplTestCase, self).setUp()
         self._walker = TestProtocolGraphWalker()
         self._walker.lines = [
-            ('have', TWO),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self._impl = self.impl_cls(self._walker)
 
     def assertNoAck(self):
@@ -453,6 +472,7 @@ class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertNextEquals(None)
         self.assertNak()
 
+
 class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
 
     impl_cls = MultiAckGraphWalkerImpl
@@ -490,17 +510,17 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
 
     def test_multi_ack_flush(self):
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()
 
         self.assertNextEquals(ONE)
-        self.assertNak() # nak the flush-pkt
+        self.assertNak()  # nak the flush-pkt
 
         self._walker.done = True
         self._impl.ack(ONE)
@@ -565,17 +585,17 @@ class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
     def test_multi_ack_flush(self):
         # same as ack test but contains a flush-pkt in the middle
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()
 
         self.assertNextEquals(ONE)
-        self.assertNak() # nak the flush-pkt
+        self.assertNak()  # nak the flush-pkt
 
         self._walker.done = True
         self._impl.ack(ONE)
@@ -604,12 +624,12 @@ class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
     def test_multi_ack_nak_flush(self):
         # same as nak test but contains a flush-pkt in the middle
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()