py3k: Go through all uses of itertools and make them work on py3k
[jelmer/dulwich.git] / dulwich / object_store.py
index 26f74ca3b4c5b0aacfb32f334952632dfebc341a..ce2ec868d9998e836baf489c2a9ff7653e1e2fae 100644 (file)
@@ -21,9 +21,9 @@
 """Git object store interfaces and implementation."""
 
 
-from cStringIO import StringIO
+from io import BytesIO
 import errno
-import itertools
+from itertools import chain
 import os
 import stat
 import tempfile
@@ -162,7 +162,8 @@ class BaseObjectStore(object):
                 yield entry
 
     def find_missing_objects(self, haves, wants, progress=None,
-                             get_tagged=None):
+                             get_tagged=None,
+                             get_parents=lambda commit: commit.parents):
         """Find the missing objects required for a set of revisions.
 
         :param haves: Iterable over SHAs already in common.
@@ -171,9 +172,10 @@ class BaseObjectStore(object):
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
+        :param get_parents: Optional function for getting the parents of a commit.
         :return: Iterator over (sha, path) pairs.
         """
-        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged, get_parents=get_parents)
         return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
@@ -183,12 +185,12 @@ class BaseObjectStore(object):
         :return: List of SHAs that are in common
         """
         haves = []
-        sha = graphwalker.next()
+        sha = next(graphwalker)
         while sha:
             if sha in self:
                 haves.append(sha)
                 graphwalker.ack(sha)
-            sha = graphwalker.next()
+            sha = next(graphwalker)
         return haves
 
     def generate_pack_contents(self, have, want, progress=None):
@@ -215,12 +217,14 @@ class BaseObjectStore(object):
             obj = self[sha]
         return obj
 
-    def _collect_ancestors(self, heads, common=set()):
+    def _collect_ancestors(self, heads, common=set(),
+                           get_parents=lambda commit: commit.parents):
         """Collect all ancestors of heads up to (excluding) those in common.
 
         :param heads: commits to start from
         :param common: commits to end at, or empty set to walk repository
             completely
+        :param get_parents: Optional function for getting the parents of a commit.
         :return: a tuple (A, B) where A - all commits reachable
             from heads but not present in common, B - common (shared) elements
             that are directly reachable from heads
@@ -236,7 +240,7 @@ class BaseObjectStore(object):
             elif e not in commits:
                 commits.add(e)
                 cmt = self[e]
-                queue.extend(cmt.parents)
+                queue.extend(get_parents(cmt))
         return (commits, bases)
 
     def close(self):
@@ -247,7 +251,7 @@ class BaseObjectStore(object):
 class PackBasedObjectStore(BaseObjectStore):
 
     def __init__(self):
-        self._pack_cache = None
+        self._pack_cache = {}
 
     @property
     def alternates(self):
@@ -275,33 +279,30 @@ class PackBasedObjectStore(BaseObjectStore):
                 return True
         return False
 
-    def _load_packs(self):
-        raise NotImplementedError(self._load_packs)
-
     def _pack_cache_stale(self):
         """Check whether the pack cache is stale."""
         raise NotImplementedError(self._pack_cache_stale)
 
-    def _add_known_pack(self, pack):
+    def _add_known_pack(self, base_name, pack):
         """Add a newly appeared pack to the cache by path.
 
         """
-        if self._pack_cache is not None:
-            self._pack_cache.append(pack)
+        self._pack_cache[base_name] = pack
 
     def close(self):
         pack_cache = self._pack_cache
-        self._pack_cache = None
+        self._pack_cache = {}
         while pack_cache:
-            pack = pack_cache.pop()
+            (name, pack) = pack_cache.popitem()
             pack.close()
 
     @property
     def packs(self):
         """List with pack objects."""
         if self._pack_cache is None or self._pack_cache_stale():
-            self._pack_cache = self._load_packs()
-        return self._pack_cache
+            self._update_pack_cache()
+
+        return self._pack_cache.values()
 
     def _iter_alternate_objects(self):
         """Iterate over the SHAs of all the objects in alternate stores."""
@@ -335,7 +336,7 @@ class PackBasedObjectStore(BaseObjectStore):
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
         iterables = self.packs + [self._iter_loose_objects()] + [self._iter_alternate_objects()]
-        return itertools.chain(*iterables)
+        return chain(*iterables)
 
     def contains_loose(self, sha):
         """Check if a particular object is present by SHA1 and is loose.
@@ -406,8 +407,12 @@ class DiskObjectStore(PackBasedObjectStore):
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self._pack_cache_time = 0
+        self._pack_cache = {}
         self._alternates = None
 
+    def __repr__(self):
+        return "<%s(%r)>" % (self.__class__.__name__, self.path)
+
     @property
     def alternates(self):
         if self._alternates is not None:
@@ -421,7 +426,7 @@ class DiskObjectStore(PackBasedObjectStore):
         try:
             f = GitFile(os.path.join(self.path, "info", "alternates"),
                     'rb')
-        except (OSError, IOError), e:
+        except (OSError, IOError) as e:
             if e.errno == errno.ENOENT:
                 return []
             raise
@@ -444,7 +449,7 @@ class DiskObjectStore(PackBasedObjectStore):
         """
         try:
             os.mkdir(os.path.join(self.path, "info"))
-        except OSError, e:
+        except OSError as e:
             if e.errno != errno.EEXIST:
                 raise
         alternates_path = os.path.join(self.path, "info/alternates")
@@ -452,7 +457,7 @@ class DiskObjectStore(PackBasedObjectStore):
         try:
             try:
                 orig_f = open(alternates_path, 'rb')
-            except (OSError, IOError), e:
+            except (OSError, IOError) as e:
                 if e.errno != errno.ENOENT:
                     raise
             else:
@@ -468,36 +473,34 @@ class DiskObjectStore(PackBasedObjectStore):
             path = os.path.join(self.path, path)
         self.alternates.append(DiskObjectStore(path))
 
-    def _load_packs(self):
-        pack_files = []
+    def _update_pack_cache(self):
         try:
-            self._pack_cache_time = os.stat(self.pack_dir).st_mtime
             pack_dir_contents = os.listdir(self.pack_dir)
-            for name in pack_dir_contents:
-                # TODO: verify that idx exists first
-                if name.startswith("pack-") and name.endswith(".pack"):
-                    filename = os.path.join(self.pack_dir, name)
-                    pack_files.append((os.stat(filename).st_mtime, filename))
-        except OSError, e:
+        except OSError as e:
             if e.errno == errno.ENOENT:
-                return []
+                self._pack_cache_time = 0
+                self.close()
+                return
             raise
-        pack_files.sort(reverse=True)
-        suffix_len = len(".pack")
-        result = []
-        try:
-            for _, f in pack_files:
-                result.append(Pack(f[:-suffix_len]))
-        except:
-            for p in result:
-                p.close()
-            raise
-        return result
+        self._pack_cache_time = os.stat(self.pack_dir).st_mtime
+        pack_files = set()
+        for name in pack_dir_contents:
+            # TODO: verify that idx exists first
+            if name.startswith("pack-") and name.endswith(".pack"):
+                pack_files.add(name[:-len(".pack")])
+
+        # Open newly appeared pack files
+        for f in pack_files:
+            if f not in self._pack_cache:
+                self._pack_cache[f] = Pack(os.path.join(self.pack_dir, f))
+        # Remove disappeared pack files
+        for f in set(self._pack_cache) - pack_files:
+            self._pack_cache.pop(f).close()
 
     def _pack_cache_stale(self):
         try:
             return os.stat(self.pack_dir).st_mtime > self._pack_cache_time
-        except OSError, e:
+        except OSError as e:
             if e.errno == errno.ENOENT:
                 return True
             raise
@@ -517,7 +520,7 @@ class DiskObjectStore(PackBasedObjectStore):
         path = self._get_shafile_path(sha)
         try:
             return ShaFile.from_path(path)
-        except (OSError, IOError), e:
+        except (OSError, IOError) as e:
             if e.errno == errno.ENOENT:
                 return None
             raise
@@ -579,7 +582,7 @@ class DiskObjectStore(PackBasedObjectStore):
         # Add the pack to the store and return it.
         final_pack = Pack(pack_base_name)
         final_pack.check_length_and_checksum()
-        self._add_known_pack(final_pack)
+        self._add_known_pack(pack_base_name, final_pack)
         return final_pack
 
     def add_thin_pack(self, read_all, read_some):
@@ -630,7 +633,7 @@ class DiskObjectStore(PackBasedObjectStore):
             p.close()
         os.rename(path, basename + ".pack")
         final_pack = Pack(basename)
-        self._add_known_pack(final_pack)
+        self._add_known_pack(basename, final_pack)
         return final_pack
 
     def add_pack(self):
@@ -663,7 +666,7 @@ class DiskObjectStore(PackBasedObjectStore):
         dir = os.path.join(self.path, obj.id[:2])
         try:
             os.mkdir(dir)
-        except OSError, e:
+        except OSError as e:
             if e.errno != errno.EEXIST:
                 raise
         path = os.path.join(dir, obj.id[2:])
@@ -679,7 +682,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def init(cls, path):
         try:
             os.mkdir(path)
-        except OSError, e:
+        except OSError as e:
             if e.errno != errno.EEXIST:
                 raise
         os.mkdir(os.path.join(path, "info"))
@@ -758,9 +761,9 @@ class MemoryObjectStore(BaseObjectStore):
         :return: Fileobject to write to and a commit function to
             call when the pack is finished.
         """
-        f = StringIO()
+        f = BytesIO()
         def commit():
-            p = PackData.from_file(StringIO(f.getvalue()), f.tell())
+            p = PackData.from_file(BytesIO(f.getvalue()), f.tell())
             f.close()
             for obj in PackInflater.for_pack_data(p):
                 self._data[obj.id] = obj
@@ -970,12 +973,14 @@ class MissingObjectFinder(object):
     :param progress: Optional function to report progress to.
     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
         sha for including tags.
+    :param get_parents: Optional function for getting the parents of a commit.
     :param tagged: dict of pointed-to sha -> tag sha for including tags
     """
 
     def __init__(self, object_store, haves, wants, progress=None,
-                 get_tagged=None):
+            get_tagged=None, get_parents=lambda commit: commit.parents):
         self.object_store = object_store
+        self._get_parents = get_parents
         # process Commits and Tags differently
         # Note, while haves may list commits/tags not available locally,
         # and such SHAs would get filtered out by _split_commits_and_tags,
@@ -987,12 +992,16 @@ class MissingObjectFinder(object):
                 _split_commits_and_tags(object_store, wants, False)
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
-        all_ancestors = object_store._collect_ancestors(have_commits)[0]
+        all_ancestors = object_store._collect_ancestors(
+                have_commits,
+                get_parents=self._get_parents)[0]
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
-        missing_commits, common_commits = \
-            object_store._collect_ancestors(want_commits, all_ancestors)
+        missing_commits, common_commits = object_store._collect_ancestors(
+            want_commits,
+            all_ancestors,
+            get_parents=self._get_parents)
         self.sha_done = set()
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
@@ -1046,6 +1055,8 @@ class MissingObjectFinder(object):
         self.progress("counting objects: %d\r" % len(self.sha_done))
         return (sha, name)
 
+    __next__ = next
+
 
 class ObjectStoreGraphWalker(object):
     """Graph walker that finds what commits are missing from an object store.
@@ -1097,3 +1108,5 @@ class ObjectStoreGraphWalker(object):
             self.heads.update([p for p in ps if not p in self.parents])
             return ret
         return None
+
+    __next__ = next