Don't try and send objects if there are none to send (client hung up already)
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
index 125de877b55f58ddef09e2e5d4560d11d5db5724..3d5412b644358fd998cf2b69210f5421212352d1 100644 (file)
@@ -5,7 +5,8 @@
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License.
+# of the License or (at your option) any later version of 
+# the License.
 # 
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
-import os
+import os, stat
 
 from commit import Commit
-from errors import MissingCommitError
-from objects import (ShaFile,
-                     Commit,
-                     Tree,
-                     Blob,
-                     )
-from pack import load_packs, iter_sha1, PackData, write_pack_index_v2
-import tempfile
+from errors import (
+        MissingCommitError, 
+        NotBlobError, 
+        NotCommitError, 
+        NotGitRepository,
+        NotTreeError, 
+        )
+from object_store import ObjectStore
+from objects import (
+        ShaFile,
+        Commit,
+        Tree,
+        Blob,
+        )
 
 OBJECTDIR = 'objects'
-PACKDIR = 'pack'
 SYMREF = 'ref: '
 
 
-class Tag(object):
+class Tags(object):
 
-    def __init__(self, name, ref):
-        self.name = name
-        self.ref = ref
+    def __init__(self, tagdir, tags):
+        self.tagdir = tagdir
+        self.tags = tags
+
+    def __getitem__(self, name):
+        return self.tags[name]
+    
+    def __setitem__(self, name, ref):
+        self.tags[name] = ref
+        f = open(os.path.join(self.tagdir, name), 'wb')
+        try:
+            f.write("%s\n" % ref)
+        finally:
+            f.close()
+
+    def __len__(self):
+        return len(self.tags)
+
+    def iteritems(self):
+        for k in self.tags:
+            yield k, self[k]
 
 
 class Repo(object):
@@ -46,45 +70,94 @@ class Repo(object):
   ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
 
   def __init__(self, root):
-    controldir = os.path.join(root, ".git")
-    if os.path.exists(os.path.join(controldir, "objects")):
+    if os.path.isdir(os.path.join(root, ".git", "objects")):
       self.bare = False
-      self._basedir = controldir
-    else:
+      self._controldir = os.path.join(root, ".git")
+    elif os.path.isdir(os.path.join(root, "objects")):
       self.bare = True
-      self._basedir = root
-    self.path = controldir
-    self.tags = [Tag(name, ref) for name, ref in self.get_tags().items()]
-    self._packs = None
-
-  def basedir(self):
-    return self._basedir
+      self._controldir = root
+    else:
+      raise NotGitRepository(root)
+    self.path = root
+    self.tags = Tags(self.tagdir(), self.get_tags())
+    self._object_store = None
+
+  def controldir(self):
+    return self._controldir
+
+  def find_missing_objects(self, determine_wants, graph_walker, progress):
+    """Fetch the missing objects required for a set of revisions.
+
+    :param determine_wants: Function that takes a dictionary with heads 
+        and returns the list of heads to fetch.
+    :param graph_walker: Object that can iterate over the list of revisions 
+        to fetch and has an "ack" method that will be called to acknowledge 
+        that a revision is present.
+    :param progress: Simple progress function that will be called with 
+        updated progress strings.
+    """
+    wants = determine_wants(self.get_refs())
+    commits_to_send = set(wants)
+    sha_done = set()
+    ref = graph_walker.next()
+    while ref:
+        if ref in self.object_store:
+            graph_walker.ack(ref)
+        ref = graph_walker.next()
+    while commits_to_send:
+        sha = (commits_to_send.pop(), None)
+        if sha in sha_done:
+            continue
+
+        c = self.commit(sha)
+        assert isinstance(c, Commit)
+        sha_done.add((sha, None))
+
+        commits_to_send.update([p for p in c.parents if not p in sha_done])
+
+        def parse_tree(tree, sha_done):
+            for mode, name, sha in tree.entries():
+                if (sha, name) in sha_done:
+                    continue
+                if mode & stat.S_IFDIR:
+                    parse_tree(self.tree(sha), sha_done)
+                sha_done.add((sha, name))
+
+        treesha = c.tree
+        if c.tree not in sha_done:
+            parse_tree(self.tree(c.tree), sha_done)
+            sha_done.add((c.tree, None))
+
+        progress("counting objects: %d\r" % len(sha_done))
+    return sha_done
+
+  def fetch_objects(self, determine_wants, graph_walker, progress):
+    """Fetch the missing objects required for a set of revisions.
+
+    :param determine_wants: Function that takes a dictionary with heads 
+        and returns the list of heads to fetch.
+    :param graph_walker: Object that can iterate over the list of revisions 
+        to fetch and has an "ack" method that will be called to acknowledge 
+        that a revision is present.
+    :param progress: Simple progress function that will be called with 
+        updated progress strings.
+    :return: tuple with number of objects, iterator over objects
+    """
+    shas = self.find_missing_objects(determine_wants, graph_walker, progress)
+    return (len(shas), ((self.get_object(sha), path) for sha, path in shas))
 
   def object_dir(self):
-    return os.path.join(self.basedir(), OBJECTDIR)
+    return os.path.join(self.controldir(), OBJECTDIR)
+
+  @property
+  def object_store(self):
+    if self._object_store is None:
+        self._object_store = ObjectStore(self.object_dir())
+    return self._object_store
 
   def pack_dir(self):
     return os.path.join(self.object_dir(), PACKDIR)
 
-  def add_pack(self):
-    fd, path = tempfile.mkstemp(dir=self.pack_dir(), suffix=".pack")
-    f = os.fdopen(fd, 'w')
-    def commit():
-       self._move_in_pack(path)
-    return f, commit
-
-  def _move_in_pack(self, path):
-    p = PackData(path)
-    entries = p.sorted_entries()
-    basename = os.path.join(self.pack_dir(), "pack-%s" % iter_sha1(entry[0] for entry in entries))
-    write_pack_index_v2(basename+".idx", entries, p.calculate_checksum())
-    os.rename(path, basename + ".pack")
-
-  def _get_packs(self):
-    if self._packs is None:
-        self._packs = list(load_packs(self.pack_dir()))
-    return self._packs
-
   def _get_ref(self, file):
     f = open(file, 'rb')
     try:
@@ -94,37 +167,51 @@ class Repo(object):
         if ref[-1] == '\n':
           ref = ref[:-1]
         return self.ref(ref)
-      assert len(contents) == 41, 'Invalid ref'
+      assert len(contents) == 41, 'Invalid ref in %s' % file
       return contents[:-1]
     finally:
       f.close()
 
   def ref(self, name):
     for dir in self.ref_locs:
-      file = os.path.join(self.basedir(), dir, name)
+      file = os.path.join(self.controldir(), dir, name)
       if os.path.exists(file):
         return self._get_ref(file)
 
+  def get_refs(self):
+    ret = {}
+    if self.head():
+        ret['HEAD'] = self.head()
+    for dir in ["refs/heads", "refs/tags"]:
+        for name in os.listdir(os.path.join(self.controldir(), dir)):
+          path = os.path.join(self.controldir(), dir, name)
+          if os.path.isfile(path):
+            ret["/".join([dir, name])] = self._get_ref(path)
+    return ret
+
   def set_ref(self, name, value):
-    file = os.path.join(self.basedir(), name)
+    file = os.path.join(self.controldir(), name)
     open(file, 'w').write(value+"\n")
 
   def remove_ref(self, name):
-    file = os.path.join(self.basedir(), name)
+    file = os.path.join(self.controldir(), name)
     if os.path.exists(file):
       os.remove(file)
       return
 
+  def tagdir(self):
+    return os.path.join(self.controldir(), 'refs', 'tags')
+
   def get_tags(self):
     ret = {}
-    for root, dirs, files in os.walk(os.path.join(self.basedir(), 'refs', 'tags')):
+    for root, dirs, files in os.walk(self.tagdir()):
       for name in files:
         ret[name] = self._get_ref(os.path.join(root, name))
     return ret
 
   def heads(self):
     ret = {}
-    for root, dirs, files in os.walk(os.path.join(self.basedir(), 'refs', 'heads')):
+    for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'heads')):
       for name in files:
         ret[name] = self._get_ref(os.path.join(root, name))
     return ret
@@ -133,22 +220,24 @@ class Repo(object):
     return self.ref('HEAD')
 
   def _get_object(self, sha, cls):
-    assert len(sha) == 40, "Incorrect length sha: %s" % str(sha)
-    dir = sha[:2]
-    file = sha[2:]
-    # Check from object dir
-    path = os.path.join(self.object_dir(), dir, file)
-    if os.path.exists(path):
-      return cls.from_file(path)
-    # Check from packs
-    for pack in self._get_packs():
-        if sha in pack:
-            return pack[sha]
-    # Should this raise instead?
-    return None
+    assert len(sha) in (20, 40)
+    ret = self.get_object(sha)
+    if ret._type != cls._type:
+        if cls is Commit:
+            raise NotCommitError(ret)
+        elif cls is Blob:
+            raise NotBlobError(ret)
+        elif cls is Tree:
+            raise NotTreeError(ret)
+        else:
+            raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
+    return ret
 
   def get_object(self, sha):
-    return self._get_object(sha, ShaFile)
+    return self.object_store[sha]
+
+  def get_parents(self, sha):
+    return self.commit(sha).parents
 
   def commit(self, sha):
     return self._get_object(sha, Commit)
@@ -176,8 +265,9 @@ class Repo(object):
     history = []
     while pending_commits != []:
       head = pending_commits.pop(0)
-      commit = self.commit(head)
-      if commit is None:
+      try:
+          commit = self.commit(head)
+      except KeyError:
         raise MissingCommitError(head)
       if commit in history:
         continue
@@ -192,6 +282,15 @@ class Repo(object):
     history.reverse()
     return history
 
+  def __repr__(self):
+      return "<Repo at %r>" % self.path
+
+  @classmethod
+  def init(cls, path, mkdir=True):
+      controldir = os.path.join(path, ".git")
+      os.mkdir(controldir)
+      cls.init_bare(controldir)
+
   @classmethod
   def init_bare(cls, path, mkdir=True):
       for d in [["objects"], 
@@ -210,3 +309,5 @@ class Repo(object):
 
   create = init_bare
 
+
+