Fetch all prerequisite revisions.
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 import os
21
22 from commit import Commit
23 from errors import (
24         MissingCommitError, 
25         NotBlobError, 
26         NotCommitError, 
27         NotGitRepository,
28         NotTreeError, 
29         )
30 from objects import (
31         ShaFile,
32         Commit,
33         Tree,
34         Blob,
35         )
36 from pack import (
37         iter_sha1, 
38         load_packs, 
39         write_pack_index_v2,
40         PackData, 
41         )
42 import tempfile
43
44 OBJECTDIR = 'objects'
45 PACKDIR = 'pack'
46 SYMREF = 'ref: '
47
48
49 class Tag(object):
50
51     def __init__(self, name, ref):
52         self.name = name
53         self.ref = ref
54
55
56 class Repo(object):
57
58   ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
59
60   def __init__(self, root):
61     if os.path.isdir(os.path.join(root, ".git", "objects")):
62       self.bare = False
63       self._controldir = os.path.join(root, ".git")
64     elif os.path.isdir(os.path.join(root, "objects")):
65       self.bare = True
66       self._controldir = root
67     else:
68       raise NotGitRepository(root)
69     self.path = root
70     self.tags = [Tag(name, ref) for name, ref in self.get_tags().items()]
71     self._object_store = None
72
73   def controldir(self):
74     return self._controldir
75
76   def find_missing_objects(self, determine_wants, graph_walker, progress):
77     """Fetch the missing objects required for a set of revisions.
78
79     :param determine_wants: Function that takes a dictionary with heads 
80         and returns the list of heads to fetch.
81     :param graph_walker: Object that can iterate over the list of revisions 
82         to fetch and has an "ack" method that will be called to acknowledge 
83         that a revision is present.
84     :param progress: Simple progress function that will be called with 
85         updated progress strings.
86     """
87     wants = determine_wants(self.heads())
88     commits_to_send = set(wants)
89     sha_done = set()
90     ref = graph_walker.next()
91     while ref:
92         sha_done.add(ref)
93         if ref in self.object_store:
94             graph_walker.ack(ref)
95         ref = graph_walker.next()
96     while commits_to_send:
97         sha = commits_to_send.pop()
98         if sha in sha_done:
99             continue
100
101         c = self.commit(sha)
102         assert isinstance(c, Commit)
103         sha_done.add(sha)
104
105         commits_to_send.update([p for p in c.parents if not p in sha_done])
106
107         def parse_tree(tree, sha_done):
108             for mode, name, x in tree.entries():
109                 if not x in sha_done:
110                     try:
111                         t = self.tree(x)
112                         sha_done.add(x)
113                         parse_tree(t, sha_done)
114                     except:
115                         sha_done.add(x)
116
117         treesha = c.tree
118         if treesha not in sha_done:
119             t = self.tree(treesha)
120             sha_done.add(treesha)
121             parse_tree(t, sha_done)
122
123         progress("counting objects: %d\r" % len(sha_done))
124     return sha_done
125
126   def fetch_objects(self, determine_wants, graph_walker, progress):
127     """Fetch the missing objects required for a set of revisions.
128
129     :param determine_wants: Function that takes a dictionary with heads 
130         and returns the list of heads to fetch.
131     :param graph_walker: Object that can iterate over the list of revisions 
132         to fetch and has an "ack" method that will be called to acknowledge 
133         that a revision is present.
134     :param progress: Simple progress function that will be called with 
135         updated progress strings.
136     """
137     shas = self.find_missing_objects(determine_wants, graph_walker, progress)
138     for sha in shas:
139         yield self.get_object(sha)
140
141   def object_dir(self):
142     return os.path.join(self.controldir(), OBJECTDIR)
143
144   @property
145   def object_store(self):
146     if self._object_store is None:
147         self._object_store = ObjectStore(self.object_dir())
148     return self._object_store
149
150   def pack_dir(self):
151     return os.path.join(self.object_dir(), PACKDIR)
152
153   def _get_ref(self, file):
154     f = open(file, 'rb')
155     try:
156       contents = f.read()
157       if contents.startswith(SYMREF):
158         ref = contents[len(SYMREF):]
159         if ref[-1] == '\n':
160           ref = ref[:-1]
161         return self.ref(ref)
162       assert len(contents) == 41, 'Invalid ref in %s' % file
163       return contents[:-1]
164     finally:
165       f.close()
166
167   def ref(self, name):
168     for dir in self.ref_locs:
169       file = os.path.join(self.controldir(), dir, name)
170       if os.path.exists(file):
171         return self._get_ref(file)
172
173   def get_refs(self):
174     ret = {"HEAD": self.head()}
175     for dir in ["refs/heads", "refs/tags"]:
176         for name in os.listdir(os.path.join(self.controldir(), dir)):
177           path = os.path.join(self.controldir(), dir, name)
178           if os.path.isfile(path):
179             ret["/".join([dir, name])] = self._get_ref(path)
180     return ret
181
182   def set_ref(self, name, value):
183     file = os.path.join(self.controldir(), name)
184     open(file, 'w').write(value+"\n")
185
186   def remove_ref(self, name):
187     file = os.path.join(self.controldir(), name)
188     if os.path.exists(file):
189       os.remove(file)
190       return
191
192   def get_tags(self):
193     ret = {}
194     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'tags')):
195       for name in files:
196         ret[name] = self._get_ref(os.path.join(root, name))
197     return ret
198
199   def heads(self):
200     ret = {}
201     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'heads')):
202       for name in files:
203         ret[name] = self._get_ref(os.path.join(root, name))
204     return ret
205
206   def head(self):
207     return self.ref('HEAD')
208
209   def _get_object(self, sha, cls):
210     assert len(sha) in (20, 40)
211     ret = self.get_object(sha)
212     if ret._type != cls._type:
213         if cls is Commit:
214             raise NotCommitError(ret)
215         elif cls is Blob:
216             raise NotBlobError(ret)
217         elif cls is Tree:
218             raise NotTreeError(ret)
219         else:
220             raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
221     return ret
222
223   def get_object(self, sha):
224     return self.object_store[sha]
225
226   def get_parents(self, sha):
227     return self.commit(sha).parents
228
229   def commit(self, sha):
230     return self._get_object(sha, Commit)
231
232   def tree(self, sha):
233     return self._get_object(sha, Tree)
234
235   def get_blob(self, sha):
236     return self._get_object(sha, Blob)
237
238   def revision_history(self, head):
239     """Returns a list of the commits reachable from head.
240
241     Returns a list of commit objects. the first of which will be the commit
242     of head, then following theat will be the parents.
243
244     Raises NotCommitError if any no commits are referenced, including if the
245     head parameter isn't the sha of a commit.
246
247     XXX: work out how to handle merges.
248     """
249     # We build the list backwards, as parents are more likely to be older
250     # than children
251     pending_commits = [head]
252     history = []
253     while pending_commits != []:
254       head = pending_commits.pop(0)
255       try:
256           commit = self.commit(head)
257       except KeyError:
258         raise MissingCommitError(head)
259       if commit in history:
260         continue
261       i = 0
262       for known_commit in history:
263         if known_commit.commit_time > commit.commit_time:
264           break
265         i += 1
266       history.insert(i, commit)
267       parents = commit.parents
268       pending_commits += parents
269     history.reverse()
270     return history
271
272   def __repr__(self):
273       return "<Repo at %r>" % self.path
274
275   @classmethod
276   def init_bare(cls, path, mkdir=True):
277       for d in [["objects"], 
278                 ["objects", "info"], 
279                 ["objects", "pack"],
280                 ["branches"],
281                 ["refs"],
282                 ["refs", "tags"],
283                 ["refs", "heads"],
284                 ["hooks"],
285                 ["info"]]:
286           os.mkdir(os.path.join(path, *d))
287       open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
288       open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
289       open(os.path.join(path, 'info', 'excludes'), 'w').write("")
290
291   create = init_bare
292
293
294 class ObjectStore(object):
295
296     def __init__(self, path):
297         self.path = path
298         self._packs = None
299
300     def pack_dir(self):
301         return os.path.join(self.path, PACKDIR)
302
303     def __contains__(self, sha):
304         # TODO: This can be more efficient
305         try:
306             self[sha]
307             return True
308         except KeyError:
309             return False
310
311     @property
312     def packs(self):
313         """List with pack objects."""
314         if self._packs is None:
315             self._packs = list(load_packs(self.pack_dir()))
316         return self._packs
317
318     def _get_shafile(self, sha):
319         dir = sha[:2]
320         file = sha[2:]
321         # Check from object dir
322         path = os.path.join(self.path, dir, file)
323         if os.path.exists(path):
324           return ShaFile.from_file(path)
325         return None
326
327     def get_raw(self, sha):
328         """Obtain the raw text for an object.
329         
330         :param sha: Sha for the object.
331         :return: tuple with object type and object contents.
332         """
333         for pack in self.packs:
334             if sha in pack:
335                 return pack.get_raw(sha, self.get_raw)
336         # FIXME: Are pack deltas ever against on-disk shafiles ?
337         ret = self._get_shafile(sha)
338         if ret is not None:
339             return ret.as_raw_string()
340         raise KeyError(sha)
341
342     def __getitem__(self, sha):
343         assert len(sha) == 40, "Incorrect length sha: %s" % str(sha)
344         ret = self._get_shafile(sha)
345         if ret is not None:
346             return ret
347         # Check from packs
348         type, uncomp = self.get_raw(sha)
349         return ShaFile.from_raw_string(type, uncomp)
350
351     def move_in_pack(self, path):
352         """Move a specific file containing a pack into the pack directory.
353
354         :note: The file should be on the same file system as the 
355             packs directory.
356
357         :param path: Path to the pack file.
358         """
359         p = PackData(path)
360         entries = p.sorted_entries(self.get_raw)
361         basename = os.path.join(self.pack_dir(), 
362             "pack-%s" % iter_sha1(entry[0] for entry in entries))
363         write_pack_index_v2(basename+".idx", entries, p.calculate_checksum())
364         os.rename(path, basename + ".pack")
365
366     def add_pack(self):
367         """Add a new pack to this object store. 
368
369         :return: Fileobject to write to and a commit function to 
370             call when the pack is finished.
371         """
372         fd, path = tempfile.mkstemp(dir=self.pack_dir(), suffix=".pack")
373         f = os.fdopen(fd, 'w')
374         def commit():
375             if os.path.getsize(path) > 0:
376                 self.move_in_pack(path)
377         return f, commit