Skip yielding objects until we've figured out the SHAs.
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
# repo.py -- For dealing with git repositories.
# Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>

# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; version 2
# of the License.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA  02110-1301, USA.

import os
import tempfile

from errors import MissingCommitError, NotBlobError, NotTreeError, NotCommitError
from objects import (ShaFile,
                     Commit,
                     Tree,
                     Blob,
                     )
from pack import load_packs, iter_sha1, PackData, write_pack_index_v2

OBJECTDIR = 'objects'
PACKDIR = 'pack'
SYMREF = 'ref: '


class Tag(object):

    def __init__(self, name, ref):
        self.name = name
        self.ref = ref


class Repo(object):

  ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']

  def __init__(self, root):
    controldir = os.path.join(root, ".git")
    if os.path.exists(os.path.join(controldir, "objects")):
      self.bare = False
      self._basedir = controldir
    else:
      self.bare = True
      self._basedir = root
    self.path = controldir
    self.tags = [Tag(name, ref) for name, ref in self.get_tags().items()]
    self._object_store = None

  def basedir(self):
    return self._basedir

  def fetch_objects(self, determine_wants, graph_walker, progress):
    wants = determine_wants(self.heads())
    # XXX: wants is not used below yet; the objects to send are taken from
    # the refs reported by the graph walker.
    commits_to_send = []
    ref = graph_walker.next()
    while ref:
        commits_to_send.append(ref)
        if ref in self.object_store:
            graph_walker.ack(ref)
        ref = graph_walker.next()
    sha_done = set()
    for sha in commits_to_send:
        if sha in sha_done:
            continue

        c = self.commit(sha)
        sha_done.add(sha)

        def parse_tree(tree, sha_done):
            for mode, name, x in tree.entries():
                if not x in sha_done:
                    try:
                        t = self.tree(x)
                        sha_done.add(x)
                        parse_tree(t, sha_done)
                    except:
                        # Not a tree (most likely a blob); sha_done is a set,
                        # so record it with add() rather than append().
                        sha_done.add(x)

        treesha = c.tree
        if treesha not in sha_done:
            t = self.tree(treesha)
            sha_done.add(treesha)
            parse_tree(t, sha_done)

        progress("counting objects: %d\r" % len(sha_done))

    # Only start yielding objects once the full set of shas has been
    # collected, so each object is yielded exactly once.
    for sha in sha_done:
        yield self.get_object(sha)

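  # Illustrative sketch for fetch_objects() above (not part of the module,
  # and it assumes the repository already has a commit): the shas of every
  # commit reported by the graph walker are collected, together with their
  # trees and blobs, and only then are the corresponding objects yielded.
  # The walker class below is hypothetical; only the next()/ack() interface
  # used above is assumed.
  #
  #   class SingleCommitWalker(object):
  #       def __init__(self, sha):
  #           self._shas = [sha]
  #       def next(self):
  #           if self._shas:
  #               return self._shas.pop()
  #           return None
  #       def ack(self, sha):
  #           pass
  #
  #   repo = Repo(".")
  #   walker = SingleCommitWalker(repo.head())
  #   objects = list(repo.fetch_objects(lambda heads: heads.values(),
  #                                     walker, lambda msg: None))
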
  def object_dir(self):
    return os.path.join(self.basedir(), OBJECTDIR)

  @property
  def object_store(self):
    if self._object_store is None:
        self._object_store = ObjectStore(self.object_dir())
    return self._object_store

  def pack_dir(self):
    return os.path.join(self.object_dir(), PACKDIR)

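  # A ref file contains either a 40-character sha followed by a newline, or a
  # symbolic reference of the form "ref: refs/heads/master\n"; _get_ref below
  # resolves symbolic references recursively via ref().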
  def _get_ref(self, file):
    f = open(file, 'rb')
    try:
      contents = f.read()
      if contents.startswith(SYMREF):
        ref = contents[len(SYMREF):]
        if ref[-1] == '\n':
          ref = ref[:-1]
        return self.ref(ref)
      assert len(contents) == 41, 'Invalid ref'
      return contents[:-1]
    finally:
      f.close()

  def ref(self, name):
    for dir in self.ref_locs:
      file = os.path.join(self.basedir(), dir, name)
      if os.path.exists(file):
        return self._get_ref(file)

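  # Example for ref() above (assuming HEAD is a symref to an existing
  # refs/heads/master): each location in ref_locs is tried in turn, so both
  # calls resolve to the same 40-character sha:
  #
  #   repo.ref('HEAD')      # follows the symref stored in the HEAD file
  #   repo.ref('master')    # found directly under refs/heads/
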
  def set_ref(self, name, value):
    file = os.path.join(self.basedir(), name)
    open(file, 'w').write(value+"\n")

  def remove_ref(self, name):
    file = os.path.join(self.basedir(), name)
    if os.path.exists(file):
      os.remove(file)
      return

  def get_tags(self):
    ret = {}
    for root, dirs, files in os.walk(os.path.join(self.basedir(), 'refs', 'tags')):
      for name in files:
        ret[name] = self._get_ref(os.path.join(root, name))
    return ret

  def heads(self):
    ret = {}
    for root, dirs, files in os.walk(os.path.join(self.basedir(), 'refs', 'heads')):
      for name in files:
        ret[name] = self._get_ref(os.path.join(root, name))
    return ret

  def head(self):
    return self.ref('HEAD')

  def _get_object(self, sha, cls):
    ret = self.get_object(sha)
    if ret._type != cls._type:
        if cls is Commit:
            raise NotCommitError(ret)
        elif cls is Blob:
            raise NotBlobError(ret)
        elif cls is Tree:
            raise NotTreeError(ret)
        else:
            raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
    return ret

  def get_object(self, sha):
    return self.object_store[sha]

  def get_parents(self, sha):
    return self.commit(sha).parents

  def commit(self, sha):
    return self._get_object(sha, Commit)

  def tree(self, sha):
    return self._get_object(sha, Tree)

  def get_blob(self, sha):
    return self._get_object(sha, Blob)

  def revision_history(self, head):
    """Returns a list of the commits reachable from head.

    Returns a list of commit objects, the first of which will be the commit
    of head, followed by its parents and their ancestors.

    Raises MissingCommitError if a referenced commit cannot be found, and
    NotCommitError if the head parameter isn't the sha of a commit.

    XXX: work out how to handle merges.
    """
    # We build the list backwards, as parents are more likely to be older
    # than children.
    pending_commits = [head]
    history = []
    while pending_commits != []:
      head = pending_commits.pop(0)
      try:
        commit = self.commit(head)
      except KeyError:
        raise MissingCommitError(head)
      if commit in history:
        continue
      # Insert so that history stays sorted by commit time (oldest first;
      # the list is reversed before returning).
      i = 0
      for known_commit in history:
        if known_commit.commit_time > commit.commit_time:
          break
        i += 1
      history.insert(i, commit)
      parents = commit.parents
      pending_commits += parents
    history.reverse()
    return history

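  # Illustrative usage of revision_history() above (assumes the repository
  # has at least one commit):
  #
  #   r = Repo(".")
  #   for c in r.revision_history(r.head()):
  #       print c.commit_time, c.parents
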
  @classmethod
  def init_bare(cls, path, mkdir=True):
      # Create the repository directory itself if requested.
      if mkdir:
          os.mkdir(path)
      for d in [["objects"],
                ["objects", "info"],
                ["objects", "pack"],
                ["branches"],
                ["refs"],
                ["refs", "tags"],
                ["refs", "heads"],
                ["hooks"],
                ["info"]]:
          os.mkdir(os.path.join(path, *d))
      open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
      open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
      open(os.path.join(path, 'info', 'exclude'), 'w').write("")

  create = init_bare
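
  # Illustrative usage of init_bare() above (the path is hypothetical):
  #
  #   Repo.init_bare("/tmp/example.git")
  #   repo = Repo("/tmp/example.git")   # opens the new repository as bare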


class ObjectStore(object):

    def __init__(self, path):
        self.path = path
        self._packs = None

    def pack_dir(self):
        return os.path.join(self.path, PACKDIR)

    def __contains__(self, sha):
        # TODO: This can be more efficient
        try:
            self[sha]
            return True
        except KeyError:
            return False

    @property
    def packs(self):
        if self._packs is None:
            self._packs = list(load_packs(self.pack_dir()))
        return self._packs

    def _get_shafile(self, sha):
        dir = sha[:2]
        file = sha[2:]
        # Check from object dir
        path = os.path.join(self.path, dir, file)
        if os.path.exists(path):
            return ShaFile.from_file(path)
        return None

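    # Loose objects live directly under self.path as <first two hex chars of
    # the sha>/<remaining 38 chars>; objects only available in packs are
    # handled by get_raw() and __getitem__() below.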
    def get_raw(self, sha):
        for pack in self.packs:
            if sha in pack:
                return pack.get_raw(sha, self.get_raw)
        # FIXME: Are pack deltas ever against on-disk shafiles ?
        ret = self._get_shafile(sha)
        if ret is not None:
            return ret.as_raw_string()
        raise KeyError(sha)

    def __getitem__(self, sha):
        assert len(sha) == 40, "Incorrect length sha: %s" % str(sha)
        ret = self._get_shafile(sha)
        if ret is not None:
            return ret
        # Check from packs
        type, uncomp = self.get_raw(sha)
        return ShaFile.from_raw_string(type, uncomp)

    def move_in_pack(self, path):
        """Index a pack written to a temporary path and move it into the
        pack directory under its content-derived name."""
        p = PackData(path)
        entries = p.sorted_entries(self.get_raw)
        basename = os.path.join(self.pack_dir(),
            "pack-%s" % iter_sha1(entry[0] for entry in entries))
        write_pack_index_v2(basename+".idx", entries, p.calculate_checksum())
        os.rename(path, basename + ".pack")

    def add_pack(self):
        fd, path = tempfile.mkstemp(dir=self.pack_dir(), suffix=".pack")
        # Pack files are binary; open the temporary file in binary mode.
        f = os.fdopen(fd, 'wb')
        def commit():
            if os.path.getsize(path) > 0:
                self.move_in_pack(path)
        return f, commit