Fix call.
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 import os
21
22 from commit import Commit
23 from errors import (
24         MissingCommitError, 
25         NotBlobError, 
26         NotCommitError, 
27         NotGitRepository,
28         NotTreeError, 
29         )
30 from object_store import ObjectStore
31 from objects import (
32         ShaFile,
33         Commit,
34         Tree,
35         Blob,
36         )
37
38 OBJECTDIR = 'objects'
39 SYMREF = 'ref: '
40
41
42 class Tags(object):
43
44     def __init__(self, tagdir, tags):
45         self.tagdir = tagdir
46         self.tags = tags
47
48     def __getitem__(self, name):
49         return self.tags[name]
50     
51     def __setitem__(self, name, ref):
52         self.tags[name] = ref
53         f = open(os.path.join(self.tagdir, name), 'wb')
54         try:
55             f.write("%s\n" % ref)
56         finally:
57             f.close()
58
59     def __len__(self):
60         return len(self.tags)
61
62     def iteritems(self):
63         for k in self.tags:
64             yield k, self[k]
65
66
67 class Repo(object):
68
69   ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
70
71   def __init__(self, root):
72     if os.path.isdir(os.path.join(root, ".git", "objects")):
73       self.bare = False
74       self._controldir = os.path.join(root, ".git")
75     elif os.path.isdir(os.path.join(root, "objects")):
76       self.bare = True
77       self._controldir = root
78     else:
79       raise NotGitRepository(root)
80     self.path = root
81     self.tags = Tags(self.tagdir(), self.get_tags())
82     self._object_store = None
83
84   def controldir(self):
85     return self._controldir
86
87   def find_missing_objects(self, determine_wants, graph_walker, progress):
88     """Fetch the missing objects required for a set of revisions.
89
90     :param determine_wants: Function that takes a dictionary with heads 
91         and returns the list of heads to fetch.
92     :param graph_walker: Object that can iterate over the list of revisions 
93         to fetch and has an "ack" method that will be called to acknowledge 
94         that a revision is present.
95     :param progress: Simple progress function that will be called with 
96         updated progress strings.
97     """
98     wants = determine_wants(self.get_refs())
99     commits_to_send = set(wants)
100     sha_done = set()
101     ref = graph_walker.next()
102     while ref:
103         sha_done.add(ref)
104         if ref in self.object_store:
105             graph_walker.ack(ref)
106         ref = graph_walker.next()
107     while commits_to_send:
108         sha = commits_to_send.pop()
109         if sha in sha_done:
110             continue
111
112         c = self.commit(sha)
113         assert isinstance(c, Commit)
114         sha_done.add(sha)
115
116         commits_to_send.update([p for p in c.parents if not p in sha_done])
117
118         def parse_tree(tree, sha_done):
119             for mode, name, x in tree.entries():
120                 if not x in sha_done:
121                     try:
122                         t = self.tree(x)
123                         sha_done.add(x)
124                         parse_tree(t, sha_done)
125                     except:
126                         sha_done.add(x)
127
128         treesha = c.tree
129         if treesha not in sha_done:
130             t = self.tree(treesha)
131             sha_done.add(treesha)
132             parse_tree(t, sha_done)
133
134         progress("counting objects: %d\r" % len(sha_done))
135     return sha_done
136
137   def fetch_objects(self, determine_wants, graph_walker, progress):
138     """Fetch the missing objects required for a set of revisions.
139
140     :param determine_wants: Function that takes a dictionary with heads 
141         and returns the list of heads to fetch.
142     :param graph_walker: Object that can iterate over the list of revisions 
143         to fetch and has an "ack" method that will be called to acknowledge 
144         that a revision is present.
145     :param progress: Simple progress function that will be called with 
146         updated progress strings.
147     """
148     shas = self.find_missing_objects(determine_wants, graph_walker, progress)
149     for sha in shas:
150         yield self.get_object(sha)
151
152   def object_dir(self):
153     return os.path.join(self.controldir(), OBJECTDIR)
154
155   @property
156   def object_store(self):
157     if self._object_store is None:
158         self._object_store = ObjectStore(self.object_dir())
159     return self._object_store
160
161   def pack_dir(self):
162     return os.path.join(self.object_dir(), PACKDIR)
163
164   def _get_ref(self, file):
165     f = open(file, 'rb')
166     try:
167       contents = f.read()
168       if contents.startswith(SYMREF):
169         ref = contents[len(SYMREF):]
170         if ref[-1] == '\n':
171           ref = ref[:-1]
172         return self.ref(ref)
173       assert len(contents) == 41, 'Invalid ref in %s' % file
174       return contents[:-1]
175     finally:
176       f.close()
177
178   def ref(self, name):
179     for dir in self.ref_locs:
180       file = os.path.join(self.controldir(), dir, name)
181       if os.path.exists(file):
182         return self._get_ref(file)
183
184   def get_refs(self):
185     ret = {}
186     if self.head():
187         ret['HEAD'] = self.head()
188     for dir in ["refs/heads", "refs/tags"]:
189         for name in os.listdir(os.path.join(self.controldir(), dir)):
190           path = os.path.join(self.controldir(), dir, name)
191           if os.path.isfile(path):
192             ret["/".join([dir, name])] = self._get_ref(path)
193     return ret
194
195   def set_ref(self, name, value):
196     file = os.path.join(self.controldir(), name)
197     open(file, 'w').write(value+"\n")
198
199   def remove_ref(self, name):
200     file = os.path.join(self.controldir(), name)
201     if os.path.exists(file):
202       os.remove(file)
203       return
204
205   def tagdir(self):
206     return os.path.join(self.controldir(), 'refs', 'tags')
207
208   def get_tags(self):
209     ret = {}
210     for root, dirs, files in os.walk(self.tagdir()):
211       for name in files:
212         ret[name] = self._get_ref(os.path.join(root, name))
213     return ret
214
215   def heads(self):
216     ret = {}
217     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'heads')):
218       for name in files:
219         ret[name] = self._get_ref(os.path.join(root, name))
220     return ret
221
222   def head(self):
223     return self.ref('HEAD')
224
225   def _get_object(self, sha, cls):
226     assert len(sha) in (20, 40)
227     ret = self.get_object(sha)
228     if ret._type != cls._type:
229         if cls is Commit:
230             raise NotCommitError(ret)
231         elif cls is Blob:
232             raise NotBlobError(ret)
233         elif cls is Tree:
234             raise NotTreeError(ret)
235         else:
236             raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
237     return ret
238
239   def get_object(self, sha):
240     return self.object_store[sha]
241
242   def get_parents(self, sha):
243     return self.commit(sha).parents
244
245   def commit(self, sha):
246     return self._get_object(sha, Commit)
247
248   def tree(self, sha):
249     return self._get_object(sha, Tree)
250
251   def get_blob(self, sha):
252     return self._get_object(sha, Blob)
253
254   def revision_history(self, head):
255     """Returns a list of the commits reachable from head.
256
257     Returns a list of commit objects. the first of which will be the commit
258     of head, then following theat will be the parents.
259
260     Raises NotCommitError if any no commits are referenced, including if the
261     head parameter isn't the sha of a commit.
262
263     XXX: work out how to handle merges.
264     """
265     # We build the list backwards, as parents are more likely to be older
266     # than children
267     pending_commits = [head]
268     history = []
269     while pending_commits != []:
270       head = pending_commits.pop(0)
271       try:
272           commit = self.commit(head)
273       except KeyError:
274         raise MissingCommitError(head)
275       if commit in history:
276         continue
277       i = 0
278       for known_commit in history:
279         if known_commit.commit_time > commit.commit_time:
280           break
281         i += 1
282       history.insert(i, commit)
283       parents = commit.parents
284       pending_commits += parents
285     history.reverse()
286     return history
287
288   def __repr__(self):
289       return "<Repo at %r>" % self.path
290
291   @classmethod
292   def init(cls, path, mkdir=True):
293       controldir = os.path.join(path, ".git")
294       os.mkdir(controldir)
295       cls.init_bare(controldir)
296
297   @classmethod
298   def init_bare(cls, path, mkdir=True):
299       for d in [["objects"], 
300                 ["objects", "info"], 
301                 ["objects", "pack"],
302                 ["branches"],
303                 ["refs"],
304                 ["refs", "tags"],
305                 ["refs", "heads"],
306                 ["hooks"],
307                 ["info"]]:
308           os.mkdir(os.path.join(path, *d))
309       open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
310       open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
311       open(os.path.join(path, 'info', 'excludes'), 'w').write("")
312
313   create = init_bare
314
315
316