Merge Dave's fixes for the compatibility tests and web.
[jelmer/dulwich-libgit2.git] / dulwich / server.py
index 2e19838d402790a49b86682b7a1525ac075b2b27..30f0c0873bb1cd80a4a9f7ae1b8beb1827fb0c6a 100644 (file)
@@ -28,7 +28,6 @@ Documentation/technical directory in the cgit distribution, and in particular:
 
 import collections
 import SocketServer
-import tempfile
 
 from dulwich.errors import (
     ApplyDeltaError,
@@ -42,6 +41,7 @@ from dulwich.protocol import (
     Protocol,
     ProtocolFile,
     TCP_GIT_PORT,
+    ZERO_SHA,
     extract_capabilities,
     extract_want_line_capabilities,
     SINGLE_ACK,
@@ -49,14 +49,27 @@ from dulwich.protocol import (
     MULTI_ACK_DETAILED,
     ack_type,
     )
-from dulwich.repo import (
-    Repo,
-    )
 from dulwich.pack import (
     write_pack_data,
     )
 
 class Backend(object):
+    """A backend for the Git smart server implementation."""
+
+    def open_repository(self, path):
+        """Open the repository at a path."""
+        raise NotImplementedError(self.open_repository)
+
+
+class BackendRepo(object):
+    """Repository abstraction used by the Git server.
+    
+    Please note that the methods required here are a 
+    subset of those provided by dulwich.repo.Repo.
+    """
+
+    object_store = None
+    refs = None
 
     def get_refs(self):
         """
@@ -66,84 +79,38 @@ class Backend(object):
         """
         raise NotImplementedError
 
-    def apply_pack(self, refs, read):
-        """ Import a set of changes into a repository and update the refs
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
 
-        :param refs: list of tuple(name, sha)
-        :param read: callback to read from the incoming pack
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to
+            point to a tag, this will be the SHA the ref refers to. If no
+            cached information about a tag is available, this method may
+            return None, but it should attempt to peel the tag if possible.
         """
-        raise NotImplementedError
+        return None
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """
         Yield the objects required for a list of commits.
 
         :param progress: is a callback to send progress messages to the client
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         """
         raise NotImplementedError
 
 
-class GitBackend(Backend):
+class DictBackend(Backend):
+    """Trivial backend that looks up Git repositories in a dictionary."""
 
-    def __init__(self, repo=None):
-        if repo is None:
-            repo = Repo(tmpfile.mkdtemp())
-        self.repo = repo
-        self.object_store = self.repo.object_store
-        self.fetch_objects = self.repo.fetch_objects
-        self.get_refs = self.repo.get_refs
+    def __init__(self, repos):
+        self.repos = repos
 
-    def apply_pack(self, refs, read):
-        f, commit = self.repo.object_store.add_thin_pack()
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
-        status = []
-        unpack_error = None
-        # TODO: more informative error messages than just the exception string
-        try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
-        except all_exceptions, e:
-            unpack_error = str(e).replace('\n', '')
-        try:
-            commit()
-        except all_exceptions, e:
-            if not unpack_error:
-                unpack_error = str(e).replace('\n', '')
-
-        if unpack_error:
-            status.append(('unpack', unpack_error))
-        else:
-            status.append(('unpack', 'ok'))
-
-        for oldsha, sha, ref in refs:
-            # TODO: check refname
-            ref_error = None
-            try:
-                if ref == "0" * 40:
-                    try:
-                        del self.repo.refs[ref]
-                    except all_exceptions:
-                        ref_error = 'failed to delete'
-                else:
-                    try:
-                        self.repo.refs[ref] = sha
-                    except all_exceptions:
-                        ref_error = 'failed to write'
-            except KeyError, e:
-                ref_error = 'bad ref'
-            if ref_error:
-                status.append((ref, ref_error))
-            else:
-                status.append((ref, 'ok'))
-
-
-        print "pack applied"
-        return status
+    def open_repository(self, path):
+        # FIXME: What to do in case there is no repo?
+        return self.repos[path]
 
 
 class Handler(object):
@@ -152,58 +119,111 @@ class Handler(object):
     def __init__(self, backend, read, write):
         self.backend = backend
         self.proto = Protocol(read, write)
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def innocuous_capabilities(self):
+        return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
+
+    def required_capabilities(self):
+        """Return a list of capabilities that we require the client to have."""
+        return []
+
+    def set_client_capabilities(self, caps):
+        allowable_caps = set(self.innocuous_capabilities())
+        allowable_caps.update(self.capabilities())
+        for cap in caps:
+            if cap not in allowable_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        for cap in self.required_capabilities():
+            if cap not in caps:
+                raise GitProtocolError('Client does not support required '
+                                       'capability %s.' % cap)
+        self._client_capabilities = set(caps)
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, read, write,
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
+        self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta")
+                "ofs-delta", "no-progress", "include-tag")
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
+    def required_capabilities(self):
+        return ("side-band-64k", "thin-pack", "ofs-delta")
+
+    def progress(self, message):
+        if self.has_capability("no-progress"):
+            return
+        self.proto.write_sideband(2, message)
 
-    def get_client_capabilities(self):
-        return self._client_capabilities
+    def get_tagged(self, refs=None, repo=None):
+        """Get a dict of peeled values of tags to their original tag shas.
 
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
+        :param refs: dict of refname -> sha of possible tags; defaults to all of
+            the backend's refs.
+        :param repo: optional Repo instance for getting peeled refs; defaults to
+            the backend's repo, if available
+        :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
+            tag whose peeled value is peeled_sha.
+        """
+        if not self.has_capability("include-tag"):
+            return {}
+        if refs is None:
+            refs = self.repo.get_refs()
+        if repo is None:
+            repo = getattr(self.repo, "repo", None)
+            if repo is None:
+                # Bail if we don't have a Repo available; this is ok since
+                # clients must be able to handle if the server doesn't include
+                # all relevant tags.
+                # TODO: fix behavior when missing
+                return {}
+        tagged = {}
+        for name, sha in refs.iteritems():
+            peeled_sha = repo.get_peeled(name)
+            if peeled_sha != sha:
+                tagged[peeled_sha] = sha
+        return tagged
 
     def handle(self):
-
-        progress = lambda x: self.proto.write_sideband(2, x)
         write = lambda x: self.proto.write_sideband(1, x)
 
-        graph_walker = ProtocolGraphWalker(self)
-        objects_iter = self.backend.fetch_objects(
-          graph_walker.determine_wants, graph_walker, progress)
+        graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
+            self.repo.get_peeled)
+        objects_iter = self.repo.fetch_objects(
+          graph_walker.determine_wants, graph_walker, self.progress,
+          get_tagged=self.get_tagged)
 
         # Do they want any objects?
         if len(objects_iter) == 0:
             return
 
-        progress("dul-daemon says what\n")
-        progress("counting objects: %d, done.\n" % len(objects_iter))
+        self.progress("dul-daemon says what\n")
+        self.progress("counting objects: %d, done.\n" % len(objects_iter))
         write_pack_data(ProtocolFile(None, write), objects_iter, 
                         len(objects_iter))
-        progress("how was that, then?\n")
+        self.progress("how was that, then?\n")
         # we are done
         self.proto.write("0000")
 
@@ -211,9 +231,9 @@ class UploadPackHandler(Handler):
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
 
-    As a graph walker, this class implements ack(), next(), and reset(). It also
-    contains some base methods for interacting with the wire and walking the
-    commit tree.
+    As a graph walker, this class implements ack(), next(), and reset(). It
+    also contains some base methods for interacting with the wire and walking
+    the commit tree.
 
     The work of determining which acks to send is passed on to the
     implementation instance stored in _impl. The reason for this is that we do
@@ -221,9 +241,10 @@ class ProtocolGraphWalker(object):
     call to set_ack_level() is required to set up the implementation, before any
     calls to next() or ack() are made.
     """
-    def __init__(self, handler):
+    def __init__(self, handler, object_store, get_peeled):
         self.handler = handler
-        self.store = handler.backend.object_store
+        self.store = object_store
+        self.get_peeled = get_peeled
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
@@ -251,9 +272,12 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
-                # TODO: include peeled value of any tags
+                peeled_sha = self.get_peeled(ref)
+                if peeled_sha != sha:
+                    self.proto.write_pkt_line('%s %s^{}\n' %
+                                              (peeled_sha, ref))
 
             # i'm done..
             self.proto.write_pkt_line(None)
@@ -266,7 +290,7 @@ class ProtocolGraphWalker(object):
         if not want:
             return []
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
 
@@ -359,10 +383,10 @@ class ProtocolGraphWalker(object):
             commit = pending.popleft()
             if commit.id in haves:
                 return True
-            if not getattr(commit, 'get_parents', None):
+            if commit.type_name != "commit":
                 # non-commit wants are assumed to be satisfied
                 continue
-            for parent in commit.get_parents():
+            for parent in commit.parents:
                 parent_obj = self.store[parent]
                 # TODO: handle parents with later commit times than children
                 if parent_obj.commit_time >= earliest:
@@ -497,32 +521,84 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, read, write,
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
+        self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def __init__(self, backend, read, write,
-                 stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
-        self._stateless_rpc = stateless_rpc
-        self._advertise_refs = advertise_refs
-
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
 
+    def _apply_pack(self, refs, read):
+        f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
+        status = []
+        unpack_error = None
+        # TODO: more informative error messages than just the exception string
+        try:
+            # TODO: decode the pack as we stream to avoid blocking reads beyond
+            # the end of data (when using HTTP/1.1 chunked encoding)
+            while True:
+                data = read(10240)
+                if not data:
+                    break
+                f.write(data)
+        except all_exceptions, e:
+            unpack_error = str(e).replace('\n', '')
+        try:
+            commit()
+        except all_exceptions, e:
+            if not unpack_error:
+                unpack_error = str(e).replace('\n', '')
+
+        if unpack_error:
+            status.append(('unpack', unpack_error))
+        else:
+            status.append(('unpack', 'ok'))
+
+        for oldsha, sha, ref in refs:
+            ref_error = None
+            try:
+                if sha == ZERO_SHA:
+                    if not self.has_capability('delete-refs'):
+                        raise GitProtocolError(
+                          'Attempted to delete refs without delete-refs '
+                          'capability.')
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_error = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_error = 'failed to write'
+            except KeyError, e:
+                ref_error = 'bad ref'
+            if ref_error:
+                status.append((ref, ref_error))
+            else:
+                status.append((ref, 'ok'))
+
+        print "pack applied"
+        return status
+
     def handle(self):
-        refs = self.backend.get_refs().items()
+        refs = self.repo.get_refs().items()
 
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                  "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                     self.capability_line()))
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("%s capabilities^{} %s" % (
+                  ZERO_SHA, self.capability_line()))
 
             self.proto.write("0000")
             if self.advertise_refs:
@@ -535,7 +611,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
             return
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
@@ -543,11 +620,11 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
-        status = self.backend.apply_pack(client_refs, self.proto.read)
+        status = self._apply_pack(client_refs, self.proto.read)
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)
@@ -572,7 +649,7 @@ class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
         else:
             return
 
-        h = cls(self.server.backend, self.rfile.read, self.wfile.write)
+        h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
         h.handle()