server: Explicitly specify allowed protocol commands.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_server.py
index 04b6f8b5d0f28bb709c6d9ef7198d29ffa70041d..7fafac63319c57c70eecea117c43c8772350a6aa 100644 (file)
@@ -21,6 +21,7 @@
 
 from dulwich.errors import (
     GitProtocolError,
+    UnexpectedCommandError,
     )
 from dulwich.server import (
     Backend,
@@ -29,6 +30,7 @@ from dulwich.server import (
     Handler,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    _split_proto_line,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     UploadPackHandler,
@@ -76,13 +78,25 @@ class TestProto(object):
             return None
 
 
+class TestGenericHandler(Handler):
+
+    def __init__(self):
+        Handler.__init__(self, Backend(), None)
+
+    @classmethod
+    def capabilities(cls):
+        return ('cap1', 'cap2', 'cap3')
+
+    @classmethod
+    def required_capabilities(cls):
+        return ('cap2',)
+
+
 class HandlerTestCase(TestCase):
 
     def setUp(self):
         super(HandlerTestCase, self).setUp()
-        self._handler = Handler(Backend(), None)
-        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
-        self._handler.required_capabilities = lambda: ('cap2',)
+        self._handler = TestGenericHandler()
 
     def assertSucceeds(self, func, *args, **kwargs):
         try:
@@ -208,7 +222,8 @@ class TestUploadPackHandler(Handler):
         self.stateless_rpc = False
         self.advertise_refs = False
 
-    def capabilities(self):
+    @classmethod
+    def capabilities(cls):
         return ('multi_ack',)
 
 
@@ -257,22 +272,21 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertFalse(self._walker.all_wants_satisfied([THREE]))
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
 
-    def test_read_proto_line(self):
-        self._walker.proto.set_output([
-          'want %s' % ONE,
-          'want %s' % TWO,
-          'have %s' % THREE,
-          'foo %s' % FOUR,
-          'bar',
-          'done',
-          ])
-        self.assertEquals(('want', ONE), self._walker.read_proto_line())
-        self.assertEquals(('want', TWO), self._walker.read_proto_line())
-        self.assertEquals(('have', THREE), self._walker.read_proto_line())
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertEquals(('done', None), self._walker.read_proto_line())
-        self.assertEquals((None, None), self._walker.read_proto_line())
+    def test_split_proto_line(self):
+        allowed = ('want', 'done', None)
+        self.assertEquals(('want', ONE),
+                          _split_proto_line('want %s\n' % ONE, allowed))
+        self.assertEquals(('want', TWO),
+                          _split_proto_line('want %s\n' % TWO, allowed))
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'want xxxx\n', allowed)
+        self.assertRaises(UnexpectedCommandError, _split_proto_line,
+                          'have %s\n' % THREE, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'foo %s\n' % FOUR, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line, 'bar', allowed)
+        self.assertEquals(('done', None), _split_proto_line('done\n', allowed))
+        self.assertEquals((None, None), _split_proto_line('', allowed))
 
     def test_determine_wants(self):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
@@ -338,8 +352,11 @@ class TestProtocolGraphWalker(object):
         self.stateless_rpc = False
         self.advertise_refs = False
 
-    def read_proto_line(self):
-        return self.lines.pop(0)
+    def read_proto_line(self, allowed):
+        command, sha = self.lines.pop(0)
+        if allowed is not None:
+            assert command in allowed
+        return command, sha
 
     def send_ack(self, sha, ack_type=''):
         self.acks.append((sha, ack_type))