object_store: Include subtrees in iteration.
authorDave Borowitz <dborowitz@google.com>
Fri, 30 Jul 2010 11:09:12 +0000 (13:09 +0200)
committerJelmer Vernooij <jelmer@samba.org>
Fri, 30 Jul 2010 11:09:12 +0000 (13:09 +0200)
NEWS
dulwich/object_store.py
dulwich/tests/test_object_store.py

diff --git a/NEWS b/NEWS
index 1955d9c21aff45e0a8784ac4ee673534b8529131..89bb03de0feae883135cdd3ee1f009fc87bb13a0 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -22,6 +22,9 @@
   * ObjectStore.iter_tree_contents now walks contents in depth-first, sorted
     order. (Dave Borowitz)
 
+  * ObjectStore.iter_tree_contents can optionally yield tree objects as well.
+    (Dave Borowitz).
+
 
 0.6.1  2010-07-22
 
index 1a834399edfb8ef28dd77ff7026e074560132f67..725504f762fa4daa7e0607c78455db7048f4c670 100644 (file)
@@ -175,24 +175,26 @@ class BaseObjectStore(object):
                     else:
                         todo.add((None, newhexsha, childpath))
 
-    def iter_tree_contents(self, tree_id):
+    def iter_tree_contents(self, tree_id, include_trees=False):
         """Iterate the contents of a tree and all subtrees.
 
-        Iteration is depth-first, as in e.g. os.walk.
+        Iteration is depth-first pre-order, as in e.g. os.walk.
 
         :param tree_id: SHA1 of the tree.
+        :param include_trees: If True, include tree objects in the iteration.
         :yield: Tuples of (path, mode, hexhsa) for objects in a tree.
         """
         todo = [('', stat.S_IFDIR, tree_id)]
         while todo:
             path, mode, hexsha = todo.pop()
-            if stat.S_ISDIR(mode):
+            is_subtree = stat.S_ISDIR(mode)
+            if not is_subtree or include_trees:
+                yield path, mode, hexsha
+            if is_subtree:
                 entries = reversed(self[hexsha].iteritems())
                 for name, entry_mode, entry_hexsha in entries:
                     entry_path = posixpath.join(path, name)
                     todo.append((entry_path, entry_mode, entry_hexsha))
-            else:
-                yield path, mode, hexsha
 
     def find_missing_objects(self, haves, wants, progress=None,
                              get_tagged=None):
index e53a612dd7b8ed9c85a9348655af9755c91f23d8..e34e75f3cb43b3f574cd2109a31d6b0e64aa9d78 100644 (file)
@@ -96,6 +96,34 @@ class ObjectStoreTests(object):
         self.assertEquals([(p, m, h) for (p, h, m) in blobs],
                           list(self.store.iter_tree_contents(tree_id)))
 
+    def test_iter_tree_contents_include_trees(self):
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ]
+        tree_id = commit_tree(self.store, blobs)
+        tree = self.store[tree_id]
+        tree_ad = self.store[tree['ad'][1]]
+        tree_bd = self.store[tree_ad['bd'][1]]
+
+        expected = [
+          ('', 0040000, tree_id),
+          ('a', 0100644, blob_a.id),
+          ('ad', 0040000, tree_ad.id),
+          ('ad/b', 0100644, blob_b.id),
+          ('ad/bd', 0040000, tree_bd.id),
+          ('ad/bd/c', 0100755, blob_c.id),
+          ]
+        actual = self.store.iter_tree_contents(tree_id, include_trees=True)
+        self.assertEquals(expected, list(actual))
+
 
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):