Add include-tag capability to server.
authorDave Borowitz <dborowitz@google.com>
Mon, 22 Feb 2010 18:26:02 +0000 (10:26 -0800)
committerDave Borowitz <dborowitz@google.com>
Thu, 4 Mar 2010 17:50:05 +0000 (09:50 -0800)
For this, we need to pass an include_tag flag as well as a dict of
shas that we kno have tags pointing to them all the way down to
MissingObjectFinder.

Added unit test to server.py, but still missing MissingObjectFinder
tests.

Change-Id: Ifd5e623c712842d0eb5133fb5fb78ab2c8c6970f

dulwich/object_store.py
dulwich/repo.py
dulwich/server.py
dulwich/tests/test_object_store.py
dulwich/tests/test_server.py

index 57e99934e5567d06cb98ee8d0d344606805a66ef..d8c78d7474834b014660c732f726b71f9569947d 100644 (file)
@@ -195,16 +195,20 @@ class BaseObjectStore(object):
                 else:
                     yield path, mode, hexsha
 
-    def find_missing_objects(self, haves, wants, progress=None):
+    def find_missing_objects(self, haves, wants, progress=None,
+                             get_tagged=None):
         """Find the missing objects required for a set of revisions.
 
         :param haves: Iterable over SHAs already in common.
         :param wants: Iterable over SHAs of objects to fetch.
         :param progress: Simple progress function that will be called with 
             updated progress strings.
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         :return: Iterator over (sha, path) pairs.
         """
-        return iter(MissingObjectFinder(self, haves, wants, progress).next, None)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
         """Find which revisions this store has in common using graphwalker.
@@ -637,9 +641,13 @@ class MissingObjectFinder(object):
     :param haves: SHA1s of commits not to send (already present in target)
     :param wants: SHA1s of commits to send
     :param progress: Optional function to report progress to.
+    :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+        sha for including tags.
+    :param tagged: dict of pointed-to sha -> tag sha for including tags
     """
 
-    def __init__(self, object_store, haves, wants, progress=None):
+    def __init__(self, object_store, haves, wants, progress=None,
+                 get_tagged=None):
         self.sha_done = set(haves)
         self.objects_to_send = set([(w, None, False) for w in wants if w not in haves])
         self.object_store = object_store
@@ -647,6 +655,7 @@ class MissingObjectFinder(object):
             self.progress = lambda x: None
         else:
             self.progress = progress
+        self._tagged = get_tagged and get_tagged() or {}
 
     def add_todo(self, entries):
         self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
@@ -673,6 +682,8 @@ class MissingObjectFinder(object):
                 self.parse_tree(o)
             elif isinstance(o, Tag):
                 self.parse_tag(o)
+        if sha in self._tagged:
+            self.add_todo([(self._tagged[sha], None, True)])
         self.sha_done.add(sha)
         self.progress("counting objects: %d\r" % len(self.sha_done))
         return (sha, name)
index 860a1d30b8a9f9e562b7dce173b90788596fce33..520fe575cceaf732f2d875f9f550b619d608e0e2 100644 (file)
@@ -648,7 +648,8 @@ class BaseRepo(object):
                 progress))
         return self.get_refs()
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """Fetch the missing objects required for a set of revisions.
 
         :param determine_wants: Function that takes a dictionary with heads 
@@ -658,12 +659,15 @@ class BaseRepo(object):
             that a revision is present.
         :param progress: Simple progress function that will be called with 
             updated progress strings.
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         :return: iterator over objects, with __len__ implemented
         """
         wants = determine_wants(self.get_refs())
         haves = self.object_store.find_common_revisions(graph_walker)
         return self.object_store.iter_shas(
-            self.object_store.find_missing_objects(haves, wants, progress))
+            self.object_store.find_missing_objects(haves, wants, progress,
+                                                   get_tagged))
 
     def get_graph_walker(self, heads=None):
         if heads is None:
index 24d8de60cd111a04b1559dde45619a314df2a89e..620abc37e43b04b132b66ce2b221c393aa0288fb 100644 (file)
@@ -76,11 +76,14 @@ class Backend(object):
         """
         raise NotImplementedError
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """
         Yield the objects required for a list of commits.
 
         :param progress: is a callback to send progress messages to the client
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         """
         raise NotImplementedError
 
@@ -91,6 +94,7 @@ class GitBackend(Backend):
         if repo is None:
             repo = Repo(tmpfile.mkdtemp())
         self.repo = repo
+        self.refs = self.repo.refs
         self.object_store = self.repo.object_store
         self.fetch_objects = self.repo.fetch_objects
         self.get_refs = self.repo.get_refs
@@ -204,7 +208,7 @@ class UploadPackHandler(Handler):
 
     def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta", "no-progress")
+                "ofs-delta", "no-progress", "include-tag")
 
     def required_capabilities(self):
         return ("side-band-64k", "thin-pack", "ofs-delta")
@@ -214,12 +218,42 @@ class UploadPackHandler(Handler):
             return
         self.proto.write_sideband(2, message)
 
+    def get_tagged(self, refs=None, repo=None):
+        """Get a dict of peeled values of tags to their original tag shas.
+
+        :param refs: dict of refname -> sha of possible tags; defaults to all of
+            the backend's refs.
+        :param repo: optional Repo instance for getting peeled refs; defaults to
+            the backend's repo, if available
+        :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
+            tag whose peeled value is peeled_sha.
+        """
+        if not self.has_capability("include-tag"):
+            return {}
+        if refs is None:
+            refs = self.backend.get_refs()
+        if repo is None:
+            repo = getattr(self.backend, "repo", None)
+            if repo is None:
+                # Bail if we don't have a Repo available; this is ok since
+                # clients must be able to handle if the server doesn't include
+                # all relevant tags.
+                # TODO: either guarantee a Repo, or fix behavior when missing
+                return {}
+        tagged = {}
+        for name, sha in refs.iteritems():
+            peeled_sha = repo.get_peeled(name)
+            if peeled_sha != sha:
+                tagged[peeled_sha] = sha
+        return tagged
+
     def handle(self):
         write = lambda x: self.proto.write_sideband(1, x)
 
         graph_walker = ProtocolGraphWalker(self)
         objects_iter = self.backend.fetch_objects(
-          graph_walker.determine_wants, graph_walker, self.progress)
+          graph_walker.determine_wants, graph_walker, self.progress,
+          get_tagged=self.get_tagged)
 
         # Do they want any objects?
         if len(objects_iter) == 0:
index 3e5ca244536bc9550cb74d88e1a0a70f3b87f81c..b012658720497c8b3a747889afdc884a683fed78 100644 (file)
@@ -99,3 +99,6 @@ class DiskObjectStoreTests(ObjectStoreTests,TestCase):
             shutil.rmtree("foo")
         os.makedirs(os.path.join("foo", "pack"))
         self.store = DiskObjectStore("foo")
+
+
+# TODO: MissingObjectFinderTests
index dff22dc1015dd5a127c285960ba5b6dd447bb2b5..76d3c9790cee33e8d8e18b664a4a405d44ad32b4 100644 (file)
@@ -142,6 +142,31 @@ class UploadPackHandlerTestCase(TestCase):
         self._handler.progress('second message')
         self.assertEqual(None, self._handler.proto.get_received_line(2))
 
+    def test_get_tagged(self):
+        refs = {
+            'refs/tags/tag1': ONE,
+            'refs/tags/tag2': TWO,
+            'refs/heads/master': FOUR,  # not a tag, no peeled value
+            }
+        peeled = {
+            'refs/tags/tag1': '1234',
+            'refs/tags/tag2': '5678',
+            }
+
+        class TestRepo(object):
+            def get_peeled(self, ref):
+                return peeled.get(ref, refs[ref])
+
+        caps = list(self._handler.required_capabilities()) + ['include-tag']
+        self._handler.set_client_capabilities(caps)
+        self.assertEquals({'1234': ONE, '5678': TWO},
+                          self._handler.get_tagged(refs, repo=TestRepo()))
+
+        # non-include-tag case
+        caps = self._handler.required_capabilities()
+        self._handler.set_client_capabilities(caps)
+        self.assertEquals({}, self._handler.get_tagged(refs, repo=TestRepo()))
+
 
 class TestCommit(object):
     def __init__(self, sha, parents, commit_time):