Merge upstream
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License or (at your option) any later version of 
9 # the License.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19 # MA  02110-1301, USA.
20
21 import os, stat
22
23 from commit import Commit
24 from errors import (
25         MissingCommitError, 
26         NotBlobError, 
27         NotCommitError, 
28         NotGitRepository,
29         NotTreeError, 
30         )
31 from object_store import ObjectStore
32 from objects import (
33         ShaFile,
34         Commit,
35         Tree,
36         Blob,
37         )
38
39 OBJECTDIR = 'objects'
40 SYMREF = 'ref: '
41
42
43 class Tags(object):
44
45     def __init__(self, tagdir, tags):
46         self.tagdir = tagdir
47         self.tags = tags
48
49     def __getitem__(self, name):
50         return self.tags[name]
51     
52     def __setitem__(self, name, ref):
53         self.tags[name] = ref
54         f = open(os.path.join(self.tagdir, name), 'wb')
55         try:
56             f.write("%s\n" % ref)
57         finally:
58             f.close()
59
60     def __len__(self):
61         return len(self.tags)
62
63     def iteritems(self):
64         for k in self.tags:
65             yield k, self[k]
66
67
68 class Repo(object):
69
70   ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
71
72   def __init__(self, root):
73     if os.path.isdir(os.path.join(root, ".git", "objects")):
74       self.bare = False
75       self._controldir = os.path.join(root, ".git")
76     elif os.path.isdir(os.path.join(root, "objects")):
77       self.bare = True
78       self._controldir = root
79     else:
80       raise NotGitRepository(root)
81     self.path = root
82     self.tags = Tags(self.tagdir(), self.get_tags())
83     self._object_store = None
84
85   def controldir(self):
86     return self._controldir
87
88   def find_missing_objects(self, determine_wants, graph_walker, progress):
89     """Fetch the missing objects required for a set of revisions.
90
91     :param determine_wants: Function that takes a dictionary with heads 
92         and returns the list of heads to fetch.
93     :param graph_walker: Object that can iterate over the list of revisions 
94         to fetch and has an "ack" method that will be called to acknowledge 
95         that a revision is present.
96     :param progress: Simple progress function that will be called with 
97         updated progress strings.
98     """
99     wants = determine_wants(self.get_refs())
100     commits_to_send = set(wants)
101     sha_done = set()
102     ref = graph_walker.next()
103     while ref:
104         sha_done.add(ref)
105         if ref in self.object_store:
106             graph_walker.ack(ref)
107         ref = graph_walker.next()
108     while commits_to_send:
109         sha = (commits_to_send.pop(), None)
110         if sha in sha_done:
111             continue
112
113         c = self.commit(sha)
114         assert isinstance(c, Commit)
115         sha_done.add((sha, None))
116
117         commits_to_send.update([p for p in c.parents if not p in sha_done])
118
119         def parse_tree(tree, sha_done):
120             for mode, name, sha in tree.entries():
121                 if sha in sha_done:
122                     continue
123                 if mode & stat.S_IFDIR:
124                     parse_tree(self.tree(sha), sha_done)
125                 sha_done.add((sha, name))
126
127         treesha = c.tree
128         if c.tree not in sha_done:
129             parse_tree(self.tree(c.tree), sha_done)
130             sha_done.add((c.tree, None))
131
132         progress("counting objects: %d\r" % len(sha_done))
133     return sha_done
134
135   def fetch_objects(self, determine_wants, graph_walker, progress):
136     """Fetch the missing objects required for a set of revisions.
137
138     :param determine_wants: Function that takes a dictionary with heads 
139         and returns the list of heads to fetch.
140     :param graph_walker: Object that can iterate over the list of revisions 
141         to fetch and has an "ack" method that will be called to acknowledge 
142         that a revision is present.
143     :param progress: Simple progress function that will be called with 
144         updated progress strings.
145     :return: tuple with number of objects, iterator over objects
146     """
147     shas = self.find_missing_objects(determine_wants, graph_walker, progress)
148     return (len(shas), (self.get_object(sha), path for sha, path in shas))
149
150   def object_dir(self):
151     return os.path.join(self.controldir(), OBJECTDIR)
152
153   @property
154   def object_store(self):
155     if self._object_store is None:
156         self._object_store = ObjectStore(self.object_dir())
157     return self._object_store
158
159   def pack_dir(self):
160     return os.path.join(self.object_dir(), PACKDIR)
161
162   def _get_ref(self, file):
163     f = open(file, 'rb')
164     try:
165       contents = f.read()
166       if contents.startswith(SYMREF):
167         ref = contents[len(SYMREF):]
168         if ref[-1] == '\n':
169           ref = ref[:-1]
170         return self.ref(ref)
171       assert len(contents) == 41, 'Invalid ref in %s' % file
172       return contents[:-1]
173     finally:
174       f.close()
175
176   def ref(self, name):
177     for dir in self.ref_locs:
178       file = os.path.join(self.controldir(), dir, name)
179       if os.path.exists(file):
180         return self._get_ref(file)
181
182   def get_refs(self):
183     ret = {}
184     if self.head():
185         ret['HEAD'] = self.head()
186     for dir in ["refs/heads", "refs/tags"]:
187         for name in os.listdir(os.path.join(self.controldir(), dir)):
188           path = os.path.join(self.controldir(), dir, name)
189           if os.path.isfile(path):
190             ret["/".join([dir, name])] = self._get_ref(path)
191     return ret
192
193   def set_ref(self, name, value):
194     file = os.path.join(self.controldir(), name)
195     open(file, 'w').write(value+"\n")
196
197   def remove_ref(self, name):
198     file = os.path.join(self.controldir(), name)
199     if os.path.exists(file):
200       os.remove(file)
201       return
202
203   def tagdir(self):
204     return os.path.join(self.controldir(), 'refs', 'tags')
205
206   def get_tags(self):
207     ret = {}
208     for root, dirs, files in os.walk(self.tagdir()):
209       for name in files:
210         ret[name] = self._get_ref(os.path.join(root, name))
211     return ret
212
213   def heads(self):
214     ret = {}
215     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'heads')):
216       for name in files:
217         ret[name] = self._get_ref(os.path.join(root, name))
218     return ret
219
220   def head(self):
221     return self.ref('HEAD')
222
223   def _get_object(self, sha, cls):
224     assert len(sha) in (20, 40)
225     ret = self.get_object(sha)
226     if ret._type != cls._type:
227         if cls is Commit:
228             raise NotCommitError(ret)
229         elif cls is Blob:
230             raise NotBlobError(ret)
231         elif cls is Tree:
232             raise NotTreeError(ret)
233         else:
234             raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
235     return ret
236
237   def get_object(self, sha):
238     return self.object_store[sha]
239
240   def get_parents(self, sha):
241     return self.commit(sha).parents
242
243   def commit(self, sha):
244     return self._get_object(sha, Commit)
245
246   def tree(self, sha):
247     return self._get_object(sha, Tree)
248
249   def get_blob(self, sha):
250     return self._get_object(sha, Blob)
251
252   def revision_history(self, head):
253     """Returns a list of the commits reachable from head.
254
255     Returns a list of commit objects. the first of which will be the commit
256     of head, then following theat will be the parents.
257
258     Raises NotCommitError if any no commits are referenced, including if the
259     head parameter isn't the sha of a commit.
260
261     XXX: work out how to handle merges.
262     """
263     # We build the list backwards, as parents are more likely to be older
264     # than children
265     pending_commits = [head]
266     history = []
267     while pending_commits != []:
268       head = pending_commits.pop(0)
269       try:
270           commit = self.commit(head)
271       except KeyError:
272         raise MissingCommitError(head)
273       if commit in history:
274         continue
275       i = 0
276       for known_commit in history:
277         if known_commit.commit_time > commit.commit_time:
278           break
279         i += 1
280       history.insert(i, commit)
281       parents = commit.parents
282       pending_commits += parents
283     history.reverse()
284     return history
285
286   def __repr__(self):
287       return "<Repo at %r>" % self.path
288
289   @classmethod
290   def init(cls, path, mkdir=True):
291       controldir = os.path.join(path, ".git")
292       os.mkdir(controldir)
293       cls.init_bare(controldir)
294
295   @classmethod
296   def init_bare(cls, path, mkdir=True):
297       for d in [["objects"], 
298                 ["objects", "info"], 
299                 ["objects", "pack"],
300                 ["branches"],
301                 ["refs"],
302                 ["refs", "tags"],
303                 ["refs", "heads"],
304                 ["hooks"],
305                 ["info"]]:
306           os.mkdir(os.path.join(path, *d))
307       open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
308       open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
309       open(os.path.join(path, 'info', 'excludes'), 'w').write("")
310
311   create = init_bare
312
313
314