Refactor server code to allow custom handler classes.
authorDave Borowitz <dborowitz@google.com>
Mon, 5 Apr 2010 23:55:28 +0000 (16:55 -0700)
committerDave Borowitz <dborowitz@google.com>
Fri, 30 Apr 2010 16:42:54 +0000 (09:42 -0700)
Also changed to use the same default handler mapping in both TCP
and HTTP servers.

Change-Id: I2fb22768a58ca6f21888593467959bb4176e64e6

dulwich/server.py
dulwich/tests/test_web.py
dulwich/web.py

index 7634473..24dd906 100644 (file)
@@ -58,6 +58,7 @@ from dulwich.protocol import (
     )
 
 
+
 class Backend(object):
     """A backend for the Git smart server implementation."""
 
@@ -654,20 +655,26 @@ class ReceivePackHandler(Handler):
             self.proto.write_pkt_line(None)
 
 
+# Default handler classes for git services.
+DEFAULT_HANDLERS = {
+  'git-upload-pack': UploadPackHandler,
+  'git-receive-pack': ReceivePackHandler,
+  }
+
+
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
+    def __init__(self, handlers, *args, **kwargs):
+        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
+
     def handle(self):
         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
 
-        # switch case to handle the specific git command
-        if command == 'git-upload-pack':
-            cls = UploadPackHandler
-        elif command == 'git-receive-pack':
-            cls = ReceivePackHandler
-        else:
-            return
-
+        cls = self.handlers.get(command, None)
+        if not callable(cls):
+            raise GitProtocolError('Invalid service %s' % command)
         h = cls(self.server.backend, args, proto)
         h.handle()
 
@@ -677,6 +684,11 @@ class TCPGitServer(SocketServer.TCPServer):
     allow_reuse_address = True
     serve = SocketServer.TCPServer.serve_forever
 
-    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
+    def _make_handler(self, *args, **kwargs):
+        return TCPGitRequestHandler(self.handlers, *args, **kwargs)
+
+    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
         self.backend = backend
-        SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
+        self.handlers = handlers
+        SocketServer.TCPServer.__init__(self, (listen_addr, port),
+                                        self._make_handler)
index 8319974..66c0422 100644 (file)
@@ -43,7 +43,8 @@ class WebTestCase(TestCase):
 
     def setUp(self):
         self._environ = {}
-        self._req = HTTPGitRequest(self._environ, self._start_response)
+        self._req = HTTPGitRequest(self._environ, self._start_response,
+                                   handlers=self._handlers())
         self._status = None
         self._headers = []
 
@@ -51,6 +52,9 @@ class WebTestCase(TestCase):
         self._status = status
         self._headers = list(headers)
 
+    def _handlers(self):
+        return None
+
 
 class DumbHandlersTestCase(WebTestCase):
 
@@ -177,7 +181,7 @@ class SmartHandlersTestCase(WebTestCase):
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
         return self._handler
 
-    def services(self):
+    def _handlers(self):
         return {'git-upload-pack': self._make_handler}
 
     def test_handle_service_request_unknown(self):
@@ -188,8 +192,7 @@ class SmartHandlersTestCase(WebTestCase):
     def test_handle_service_request(self):
         self._environ['wsgi.input'] = StringIO('foo')
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat,
-                                                services=self.services()))
+        output = ''.join(handle_service_request(self._req, 'backend', mat))
         self.assertEqual('handled input: foo', output)
         response_type = 'application/x-git-upload-pack-response'
         self.assertTrue(('Content-Type', response_type) in self._headers)
@@ -200,16 +203,14 @@ class SmartHandlersTestCase(WebTestCase):
         self._environ['wsgi.input'] = StringIO('foobar')
         self._environ['CONTENT_LENGTH'] = 3
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat,
-                                                services=self.services()))
+        output = ''.join(handle_service_request(self._req, 'backend', mat))
         self.assertEqual('handled input: foo', output)
         response_type = 'application/x-git-upload-pack-response'
         self.assertTrue(('Content-Type', response_type) in self._headers)
 
     def test_get_info_refs_unknown(self):
         self._environ['QUERY_STRING'] = 'service=git-evil-handler'
-        list(get_info_refs(self._req, 'backend', None,
-                           services=self.services()))
+        list(get_info_refs(self._req, 'backend', None))
         self.assertEquals(HTTP_FORBIDDEN, self._status)
 
     def test_get_info_refs(self):
@@ -217,8 +218,7 @@ class SmartHandlersTestCase(WebTestCase):
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
 
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(get_info_refs(self._req, 'backend', mat,
-                                       services=self.services()))
+        output = ''.join(get_info_refs(self._req, 'backend', mat))
         self.assertEquals(('001e# service=git-upload-pack\n'
                            '0000'
                            # input is ignored by the handler
index 9e8ed22..8e15453 100644 (file)
@@ -32,8 +32,11 @@ from dulwich.protocol import (
 from dulwich.server import (
     ReceivePackHandler,
     UploadPackHandler,
+    DEFAULT_HANDLERS,
     )
 
+
+# HTTP error strings
 HTTP_OK = '200 OK'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
@@ -128,15 +131,11 @@ def get_idx_file(req, backend, mat):
                      'application/x-git-packed-objects-toc')
 
 
-default_services = {'git-upload-pack': UploadPackHandler,
-                    'git-receive-pack': ReceivePackHandler}
-def get_info_refs(req, backend, mat, services=None):
-    if services is None:
-        services = default_services
+def get_info_refs(req, backend, mat):
     params = parse_qs(req.environ['QUERY_STRING'])
     service = params.get('service', [None])[0]
     if service and not req.dumb:
-        handler_cls = services.get(service, None)
+        handler_cls = req.handlers.get(service, None)
         if handler_cls is None:
             yield req.forbidden('Unsupported service %s' % service)
             return
@@ -202,11 +201,9 @@ class _LengthLimitedFile(object):
     # TODO: support more methods as necessary
 
 
-def handle_service_request(req, backend, mat, services=None):
-    if services is None:
-        services = default_services
+def handle_service_request(req, backend, mat):
     service = mat.group().lstrip('/')
-    handler_cls = services.get(service, None)
+    handler_cls = req.handlers.get(service, None)
     if handler_cls is None:
         yield req.forbidden('Unsupported service %s' % service)
         return
@@ -233,9 +230,10 @@ class HTTPGitRequest(object):
     :ivar environ: the WSGI environment for the request.
     """
 
-    def __init__(self, environ, start_response, dumb=False):
+    def __init__(self, environ, start_response, dumb=False, handlers=None):
         self.environ = environ
         self.dumb = dumb
+        self.handlers = handlers and handlers or DEFAULT_HANDLERS
         self._start_response = start_response
         self._cache_headers = []
         self._headers = []
@@ -304,14 +302,16 @@ class HTTPGitApplication(object):
       ('POST', re.compile('/git-receive-pack$')): handle_service_request,
     }
 
-    def __init__(self, backend, dumb=False):
+    def __init__(self, backend, dumb=False, handlers=None):
         self.backend = backend
         self.dumb = dumb
+        self.handlers = handlers
 
     def __call__(self, environ, start_response):
         path = environ['PATH_INFO']
         method = environ['REQUEST_METHOD']
-        req = HTTPGitRequest(environ, start_response, self.dumb)
+        req = HTTPGitRequest(environ, start_response, dumb=self.dumb,
+                             handlers=self.handlers)
         # environ['QUERY_STRING'] has qs args
         handler = None
         for smethod, spath in self.services.iterkeys():