Merge John.
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License or (at your option) any later version of 
9 # the License.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19 # MA  02110-1301, USA.
20
21 import os
22
23 from commit import Commit
24 from errors import (
25         MissingCommitError, 
26         NotBlobError, 
27         NotCommitError, 
28         NotGitRepository,
29         NotTreeError, 
30         )
31 from object_store import ObjectStore
32 from objects import (
33         ShaFile,
34         Commit,
35         Tree,
36         Blob,
37         )
38
39 OBJECTDIR = 'objects'
40 SYMREF = 'ref: '
41
42
43 class Tags(object):
44
45     def __init__(self, tagdir, tags):
46         self.tagdir = tagdir
47         self.tags = tags
48
49     def __getitem__(self, name):
50         return self.tags[name]
51     
52     def __setitem__(self, name, ref):
53         self.tags[name] = ref
54         f = open(os.path.join(self.tagdir, name), 'wb')
55         try:
56             f.write("%s\n" % ref)
57         finally:
58             f.close()
59
60     def __len__(self):
61         return len(self.tags)
62
63     def iteritems(self):
64         for k in self.tags:
65             yield k, self[k]
66
67
68 class Repo(object):
69
70   ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
71
72   def __init__(self, root):
73     if os.path.isdir(os.path.join(root, ".git", "objects")):
74       self.bare = False
75       self._controldir = os.path.join(root, ".git")
76     elif os.path.isdir(os.path.join(root, "objects")):
77       self.bare = True
78       self._controldir = root
79     else:
80       raise NotGitRepository(root)
81     self.path = root
82     self.tags = Tags(self.tagdir(), self.get_tags())
83     self._object_store = None
84
85   def controldir(self):
86     return self._controldir
87
88   def find_missing_objects(self, determine_wants, graph_walker, progress):
89     """Fetch the missing objects required for a set of revisions.
90
91     :param determine_wants: Function that takes a dictionary with heads 
92         and returns the list of heads to fetch.
93     :param graph_walker: Object that can iterate over the list of revisions 
94         to fetch and has an "ack" method that will be called to acknowledge 
95         that a revision is present.
96     :param progress: Simple progress function that will be called with 
97         updated progress strings.
98     """
99     wants = determine_wants(self.get_refs())
100     commits_to_send = set(wants)
101     sha_done = set()
102     ref = graph_walker.next()
103     while ref:
104         sha_done.add(ref)
105         if ref in self.object_store:
106             graph_walker.ack(ref)
107         ref = graph_walker.next()
108     while commits_to_send:
109         sha = commits_to_send.pop()
110         if sha in sha_done:
111             continue
112
113         c = self.commit(sha)
114         assert isinstance(c, Commit)
115         sha_done.add(sha)
116
117         commits_to_send.update([p for p in c.parents if not p in sha_done])
118
119         def parse_tree(tree, sha_done):
120             for mode, name, x in tree.entries():
121                 if not x in sha_done:
122                     try:
123                         t = self.tree(x)
124                         sha_done.add(x)
125                         parse_tree(t, sha_done)
126                     except:
127                         sha_done.add(x)
128
129         treesha = c.tree
130         if treesha not in sha_done:
131             t = self.tree(treesha)
132             sha_done.add(treesha)
133             parse_tree(t, sha_done)
134
135         progress("counting objects: %d\r" % len(sha_done))
136     return sha_done
137
138   def fetch_objects(self, determine_wants, graph_walker, progress):
139     """Fetch the missing objects required for a set of revisions.
140
141     :param determine_wants: Function that takes a dictionary with heads 
142         and returns the list of heads to fetch.
143     :param graph_walker: Object that can iterate over the list of revisions 
144         to fetch and has an "ack" method that will be called to acknowledge 
145         that a revision is present.
146     :param progress: Simple progress function that will be called with 
147         updated progress strings.
148     """
149     shas = self.find_missing_objects(determine_wants, graph_walker, progress)
150     for sha in shas:
151         yield self.get_object(sha)
152
153   def object_dir(self):
154     return os.path.join(self.controldir(), OBJECTDIR)
155
156   @property
157   def object_store(self):
158     if self._object_store is None:
159         self._object_store = ObjectStore(self.object_dir())
160     return self._object_store
161
162   def pack_dir(self):
163     return os.path.join(self.object_dir(), PACKDIR)
164
165   def _get_ref(self, file):
166     f = open(file, 'rb')
167     try:
168       contents = f.read()
169       if contents.startswith(SYMREF):
170         ref = contents[len(SYMREF):]
171         if ref[-1] == '\n':
172           ref = ref[:-1]
173         return self.ref(ref)
174       assert len(contents) == 41, 'Invalid ref in %s' % file
175       return contents[:-1]
176     finally:
177       f.close()
178
179   def ref(self, name):
180     for dir in self.ref_locs:
181       file = os.path.join(self.controldir(), dir, name)
182       if os.path.exists(file):
183         return self._get_ref(file)
184
185   def get_refs(self):
186     ret = {}
187     if self.head():
188         ret['HEAD'] = self.head()
189     for dir in ["refs/heads", "refs/tags"]:
190         for name in os.listdir(os.path.join(self.controldir(), dir)):
191           path = os.path.join(self.controldir(), dir, name)
192           if os.path.isfile(path):
193             ret["/".join([dir, name])] = self._get_ref(path)
194     return ret
195
196   def set_ref(self, name, value):
197     file = os.path.join(self.controldir(), name)
198     open(file, 'w').write(value+"\n")
199
200   def remove_ref(self, name):
201     file = os.path.join(self.controldir(), name)
202     if os.path.exists(file):
203       os.remove(file)
204       return
205
206   def tagdir(self):
207     return os.path.join(self.controldir(), 'refs', 'tags')
208
209   def get_tags(self):
210     ret = {}
211     for root, dirs, files in os.walk(self.tagdir()):
212       for name in files:
213         ret[name] = self._get_ref(os.path.join(root, name))
214     return ret
215
216   def heads(self):
217     ret = {}
218     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'heads')):
219       for name in files:
220         ret[name] = self._get_ref(os.path.join(root, name))
221     return ret
222
223   def head(self):
224     return self.ref('HEAD')
225
226   def _get_object(self, sha, cls):
227     assert len(sha) in (20, 40)
228     ret = self.get_object(sha)
229     if ret._type != cls._type:
230         if cls is Commit:
231             raise NotCommitError(ret)
232         elif cls is Blob:
233             raise NotBlobError(ret)
234         elif cls is Tree:
235             raise NotTreeError(ret)
236         else:
237             raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
238     return ret
239
240   def get_object(self, sha):
241     return self.object_store[sha]
242
243   def get_parents(self, sha):
244     return self.commit(sha).parents
245
246   def commit(self, sha):
247     return self._get_object(sha, Commit)
248
249   def tree(self, sha):
250     return self._get_object(sha, Tree)
251
252   def get_blob(self, sha):
253     return self._get_object(sha, Blob)
254
255   def revision_history(self, head):
256     """Returns a list of the commits reachable from head.
257
258     Returns a list of commit objects. the first of which will be the commit
259     of head, then following theat will be the parents.
260
261     Raises NotCommitError if any no commits are referenced, including if the
262     head parameter isn't the sha of a commit.
263
264     XXX: work out how to handle merges.
265     """
266     # We build the list backwards, as parents are more likely to be older
267     # than children
268     pending_commits = [head]
269     history = []
270     while pending_commits != []:
271       head = pending_commits.pop(0)
272       try:
273           commit = self.commit(head)
274       except KeyError:
275         raise MissingCommitError(head)
276       if commit in history:
277         continue
278       i = 0
279       for known_commit in history:
280         if known_commit.commit_time > commit.commit_time:
281           break
282         i += 1
283       history.insert(i, commit)
284       parents = commit.parents
285       pending_commits += parents
286     history.reverse()
287     return history
288
289   def __repr__(self):
290       return "<Repo at %r>" % self.path
291
292   @classmethod
293   def init(cls, path, mkdir=True):
294       controldir = os.path.join(path, ".git")
295       os.mkdir(controldir)
296       cls.init_bare(controldir)
297
298   @classmethod
299   def init_bare(cls, path, mkdir=True):
300       for d in [["objects"], 
301                 ["objects", "info"], 
302                 ["objects", "pack"],
303                 ["branches"],
304                 ["refs"],
305                 ["refs", "tags"],
306                 ["refs", "heads"],
307                 ["hooks"],
308                 ["info"]]:
309           os.mkdir(os.path.join(path, *d))
310       open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
311       open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
312       open(os.path.join(path, 'info', 'excludes'), 'w').write("")
313
314   create = init_bare
315
316
317