Clean up file headers.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
index cf6ca25dd1a26cde797e24f9493475b736ccc741..171eb41c475ba9bbe58836460b54e1222d24ab17 100644 (file)
@@ -1,5 +1,5 @@
 # server.py -- Implementation of the server side git protocols
-# Copryight (C) 2008 John Carr <john.carr@unrouted.co.uk>
+# Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
 #
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
@@ -27,32 +27,38 @@ Documentation/technical directory in the cgit distribution, and in particular:
 
 
 import collections
+import socket
+import zlib
 import SocketServer
 
 from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     GitProtocolError,
+    ObjectFormatException,
     )
 from dulwich.objects import (
     hex_to_sha,
     )
+from dulwich.pack import (
+    PackStreamReader,
+    write_pack_data,
+    )
 from dulwich.protocol import (
-    Protocol,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
     ProtocolFile,
+    ReceivableProtocol,
+    SINGLE_ACK,
     TCP_GIT_PORT,
     ZERO_SHA,
+    ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
-    SINGLE_ACK,
-    MULTI_ACK,
-    MULTI_ACK_DETAILED,
-    ack_type,
-    )
-from dulwich.pack import (
-    write_pack_data,
     )
 
+
+
 class Backend(object):
     """A backend for the Git smart server implementation."""
 
@@ -64,9 +70,13 @@ class Backend(object):
 class BackendRepo(object):
     """Repository abstraction used by the Git server.
     
-    Eventually this should become just a subset of Repo.
+    Please note that the methods required here are a 
+    subset of those provided by dulwich.repo.Repo.
     """
 
+    object_store = None
+    refs = None
+
     def get_refs(self):
         """
         Get all the refs in the repository
@@ -80,21 +90,12 @@ class BackendRepo(object):
 
         :param name: Name of the ref to peel
         :return: The peeled value of the ref. If the ref is known not point to
-            a tag, this will be the SHA the ref refers to. If the ref may 
-            point to a tag, but no cached information is available, None is 
-            returned.
+            a tag, this will be the SHA the ref refers to. If no cached
+            information about a tag is available, this method may return None,
+            but it should attempt to peel the tag if possible.
         """
         return None
 
-    def apply_pack(self, refs, read, delete_refs=True):
-        """ Import a set of changes into a repository and update the refs
-
-        :param refs: list of tuple(name, sha)
-        :param read: callback to read from the incoming pack
-        :param delete_refs: whether to allow deleting refs
-        """
-        raise NotImplementedError
-
     def fetch_objects(self, determine_wants, graph_walker, progress,
                       get_tagged=None):
         """
@@ -107,68 +108,30 @@ class BackendRepo(object):
         raise NotImplementedError
 
 
-class GitBackendRepo(BackendRepo):
+class PackStreamCopier(PackStreamReader):
+    """Class to verify a pack stream as it is being read.
 
-    def __init__(self, repo):
-        self.repo = repo
-        self.refs = self.repo.refs
-        self.object_store = self.repo.object_store
-        self.fetch_objects = self.repo.fetch_objects
-        self.get_refs = self.repo.get_refs
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate and written out to the given file-like object.
+    """
 
-    def apply_pack(self, refs, read, delete_refs=True):
-        f, commit = self.repo.object_store.add_thin_pack()
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
-        status = []
-        unpack_error = None
-        # TODO: more informative error messages than just the exception string
-        try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
-        except all_exceptions, e:
-            unpack_error = str(e).replace('\n', '')
-        try:
-            commit()
-        except all_exceptions, e:
-            if not unpack_error:
-                unpack_error = str(e).replace('\n', '')
+    def __init__(self, read_all, read_some, outfile):
+        super(PackStreamCopier, self).__init__(read_all, read_some)
+        self.outfile = outfile
 
-        if unpack_error:
-            status.append(('unpack', unpack_error))
-        else:
-            status.append(('unpack', 'ok'))
+    def _read(self, read, size):
+        data = super(PackStreamCopier, self)._read(read, size)
+        self.outfile.write(data)
+        return data
 
-        for oldsha, sha, ref in refs:
-            ref_error = None
-            try:
-                if sha == ZERO_SHA:
-                    if not delete_refs:
-                        raise GitProtocolError(
-                          'Attempted to delete refs without delete-refs '
-                          'capability.')
-                    try:
-                        del self.repo.refs[ref]
-                    except all_exceptions:
-                        ref_error = 'failed to delete'
-                else:
-                    try:
-                        self.repo.refs[ref] = sha
-                    except all_exceptions:
-                        ref_error = 'failed to write'
-            except KeyError, e:
-                ref_error = 'bad ref'
-            if ref_error:
-                status.append((ref, ref_error))
-            else:
-                status.append((ref, 'ok'))
+    def verify(self):
+        """Verify a pack stream and write it to the output file.
 
-        print "pack applied"
-        return status
+        See PackStreamReader.iterobjects for a list of exceptions this may
+        throw.
+        """
+        for _, _, _ in self.read_objects():
+            pass
 
 
 class DictBackend(Backend):
@@ -185,9 +148,9 @@ class DictBackend(Backend):
 class Handler(object):
     """Smart protocol command handler base class."""
 
-    def __init__(self, backend, read, write):
+    def __init__(self, backend, proto):
         self.backend = backend
-        self.proto = Protocol(read, write)
+        self.proto = proto
         self._client_capabilities = None
 
     def capability_line(self):
@@ -226,9 +189,9 @@ class Handler(object):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
 
-    def __init__(self, backend, args, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
         self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
@@ -367,10 +330,10 @@ class ProtocolGraphWalker(object):
         while command != None:
             if command != 'want':
                 raise GitProtocolError(
-                    'Protocol got unexpected command %s' % command)
+                  'Protocol got unexpected command %s' % command)
             if sha not in values:
                 raise GitProtocolError(
-                    'Client wants invalid object %s' % sha)
+                  'Client wants invalid object %s' % sha)
             want_revs.append(sha)
             command, sha = self.read_proto_line()
 
@@ -478,10 +441,10 @@ class ProtocolGraphWalker(object):
 
     def set_ack_type(self, ack_type):
         impl_classes = {
-            MULTI_ACK: MultiAckGraphWalkerImpl,
-            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
-            SINGLE_ACK: SingleAckGraphWalkerImpl,
-            }
+          MULTI_ACK: MultiAckGraphWalkerImpl,
+          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+          SINGLE_ACK: SingleAckGraphWalkerImpl,
+          }
         self._impl = impl_classes[ack_type](self)
 
 
@@ -590,9 +553,9 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, args, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
         self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
@@ -600,6 +563,48 @@ class ReceivePackHandler(Handler):
     def capabilities(self):
         return ("report-status", "delete-refs")
 
+    def _apply_pack(self, refs):
+        f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
+                          AssertionError, socket.error, zlib.error,
+                          ObjectFormatException)
+        status = []
+        # TODO: more informative error messages than just the exception string
+        try:
+            PackStreamCopier(self.proto.read, self.proto.recv, f).verify()
+            p = commit()
+            if not p:
+                raise IOError('Failed to write pack')
+            p.check()
+            status.append(('unpack', 'ok'))
+        except all_exceptions, e:
+            status.append(('unpack', str(e).replace('\n', '')))
+            # The pack may still have been moved in, but it may contain broken
+            # objects. We trust a later GC to clean it up.
+
+        for oldsha, sha, ref in refs:
+            ref_status = 'ok'
+            try:
+                if sha == ZERO_SHA:
+                    if not 'delete-refs' in self.capabilities():
+                        raise GitProtocolError(
+                          'Attempted to delete refs without delete-refs '
+                          'capability.')
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_status = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_status = 'failed to write'
+            except KeyError, e:
+                ref_status = 'bad ref'
+            status.append((ref, ref_status))
+
+        return status
+
     def handle(self):
         refs = self.repo.get_refs().items()
 
@@ -635,8 +640,7 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
-        status = self.repo.apply_pack(client_refs, self.proto.read,
-            self.has_capability('delete-refs'))
+        status = self._apply_pack(client_refs)
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
@@ -651,21 +655,27 @@ class ReceivePackHandler(Handler):
             self.proto.write_pkt_line(None)
 
 
+# Default handler classes for git services.
+DEFAULT_HANDLERS = {
+  'git-upload-pack': UploadPackHandler,
+  'git-receive-pack': ReceivePackHandler,
+  }
+
+
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
+    def __init__(self, handlers, *args, **kwargs):
+        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
+
     def handle(self):
-        proto = Protocol(self.rfile.read, self.wfile.write)
+        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
 
-        # switch case to handle the specific git command
-        if command == 'git-upload-pack':
-            cls = UploadPackHandler
-        elif command == 'git-receive-pack':
-            cls = ReceivePackHandler
-        else:
-            return
-
-        h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
+        cls = self.handlers.get(command, None)
+        if not callable(cls):
+            raise GitProtocolError('Invalid service %s' % command)
+        h = cls(self.server.backend, args, proto)
         h.handle()
 
 
@@ -674,6 +684,11 @@ class TCPGitServer(SocketServer.TCPServer):
     allow_reuse_address = True
     serve = SocketServer.TCPServer.serve_forever
 
-    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
+    def _make_handler(self, *args, **kwargs):
+        return TCPGitRequestHandler(self.handlers, *args, **kwargs)
+
+    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
         self.backend = backend
-        SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
+        self.handlers = handlers
+        SocketServer.TCPServer.__init__(self, (listen_addr, port),
+                                        self._make_handler)