Fix docstrings
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License or (at your option) any later version of 
9 # the License.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19 # MA  02110-1301, USA.
20
21 import os
22 import stat
23
24 from dulwich.errors import (
25     MissingCommitError, 
26     NotBlobError, 
27     NotCommitError, 
28     NotGitRepository,
29     NotTreeError, 
30     )
31 from dulwich.object_store import (
32     ObjectStore,
33     )
34 from dulwich.objects import (
35     Blob,
36     Commit,
37     ShaFile,
38     Tag,
39     Tree,
40     )
41
42 OBJECTDIR = 'objects'
43 SYMREF = 'ref: '
44 REFSDIR = 'refs'
45 INDEX_FILENAME = "index"
46
47 class Tags(object):
48     """Tags container."""
49
50     def __init__(self, tagdir, tags):
51         self.tagdir = tagdir
52         self.tags = tags
53
54     def __getitem__(self, name):
55         return self.tags[name]
56     
57     def __setitem__(self, name, ref):
58         self.tags[name] = ref
59         f = open(os.path.join(self.tagdir, name), 'wb')
60         try:
61             f.write("%s\n" % ref)
62         finally:
63             f.close()
64
65     def __len__(self):
66         return len(self.tags)
67
68     def iteritems(self):
69         for k in self.tags:
70             yield k, self[k]
71
72
73 def read_packed_refs(f):
74     l = f.readline()
75     for l in f.readlines():
76         if l[0] == "#":
77             # Comment
78             continue
79         if l[0] == "^":
80             # FIXME: Return somehow
81             continue
82         yield tuple(l.rstrip("\n").split(" ", 2))
83
84
85 class MissingObjectFinder(object):
86     """Find the objects missing from another git repository.
87
88     :param object_store: Object store containing at least all objects to be 
89         sent
90     :param wants: SHA1s of commits to send
91     :param graph_walker: graph walker object used to see what the remote 
92         repo has and misses
93     :param progress: Optional function to report progress to.
94     """
95
96     def __init__(self, object_store, wants, graph_walker, progress=None):
97         self.sha_done = set()
98         self.objects_to_send = set([(w, None) for w in wants])
99         self.object_store = object_store
100         if progress is None:
101             self.progress = lambda x: None
102         else:
103             self.progress = progress
104         ref = graph_walker.next()
105         while ref:
106             if ref in self.object_store:
107                 graph_walker.ack(ref)
108             ref = graph_walker.next()
109
110     def add_todo(self, entries):
111         self.objects_to_send.update([e for e in entries if not e in self.sha_done])
112
113     def parse_tree(self, tree):
114         self.add_todo([(sha, name) for (mode, name, sha) in tree.entries()])
115
116     def parse_commit(self, commit):
117         self.add_todo([(commit.tree, "")])
118         self.add_todo([(p, None) for p in commit.parents])
119
120     def parse_tag(self, tag):
121         self.add_todo([(tag.object[1], None)])
122
123     def next(self):
124         if not self.objects_to_send:
125             return None
126         (sha, name) = self.objects_to_send.pop()
127         o = self.object_store[sha]
128         if isinstance(o, Commit):
129             self.parse_commit(o)
130         elif isinstance(o, Tree):
131             self.parse_tree(o)
132         elif isinstance(o, Tag):
133             self.parse_tag(o)
134         self.sha_done.add((sha, name))
135         self.progress("counting objects: %d\r" % len(self.sha_done))
136         return (sha, name)
137
138
139 class Repo(object):
140     """A local git repository."""
141
142     ref_locs = ['', REFSDIR, 'refs/tags', 'refs/heads', 'refs/remotes']
143
144     def __init__(self, root):
145         if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
146             self.bare = False
147             self._controldir = os.path.join(root, ".git")
148         elif os.path.isdir(os.path.join(root, OBJECTDIR)):
149             self.bare = True
150             self._controldir = root
151         else:
152             raise NotGitRepository(root)
153         self.path = root
154         self.tags = Tags(self.tagdir(), self.get_tags())
155         self._object_store = None
156
157     def controldir(self):
158         """Return the path of the control directory."""
159         return self._controldir
160
161     def index_path(self):
162         return os.path.join(self.controldir(), INDEX_FILENAME)
163
164     def open_index(self):
165         """Open the index for this repository."""
166         from dulwich.index import Index
167         return Index(self.index_path())
168
169     def has_index(self):
170         """Check if an index is present."""
171         return os.path.exists(self.index_path())
172
173     def find_missing_objects(self, determine_wants, graph_walker, progress):
174         """Find the missing objects required for a set of revisions.
175
176         :param determine_wants: Function that takes a dictionary with heads 
177             and returns the list of heads to fetch.
178         :param graph_walker: Object that can iterate over the list of revisions 
179             to fetch and has an "ack" method that will be called to acknowledge 
180             that a revision is present.
181         :param progress: Simple progress function that will be called with 
182             updated progress strings.
183         """
184         wants = determine_wants(self.get_refs())
185         return iter(MissingObjectFinder(self.object_store, wants, graph_walker, 
186                 progress).next, None)
187
188     def fetch_objects(self, determine_wants, graph_walker, progress):
189         """Fetch the missing objects required for a set of revisions.
190
191         :param determine_wants: Function that takes a dictionary with heads 
192             and returns the list of heads to fetch.
193         :param graph_walker: Object that can iterate over the list of revisions 
194             to fetch and has an "ack" method that will be called to acknowledge 
195             that a revision is present.
196         :param progress: Simple progress function that will be called with 
197             updated progress strings.
198         :return: tuple with number of objects, iterator over objects
199         """
200         return self.object_store.iter_shas(
201             self.find_missing_objects(determine_wants, graph_walker, progress))
202
203     def object_dir(self):
204         return os.path.join(self.controldir(), OBJECTDIR)
205
206     @property
207     def object_store(self):
208         if self._object_store is None:
209             self._object_store = ObjectStore(self.object_dir())
210         return self._object_store
211
212     def pack_dir(self):
213         return os.path.join(self.object_dir(), PACKDIR)
214
215     def _get_ref(self, file):
216         f = open(file, 'rb')
217         try:
218             contents = f.read()
219             if contents.startswith(SYMREF):
220                 ref = contents[len(SYMREF):]
221                 if ref[-1] == '\n':
222                     ref = ref[:-1]
223                 return self.ref(ref)
224             assert len(contents) == 41, 'Invalid ref in %s' % file
225             return contents[:-1]
226         finally:
227             f.close()
228
229     def ref(self, name):
230         """Return the SHA1 a ref is pointing to."""
231         for dir in self.ref_locs:
232             file = os.path.join(self.controldir(), dir, name)
233             if os.path.exists(file):
234                 return self._get_ref(file)
235         packed_refs = self.get_packed_refs()
236         if name in packed_refs:
237             return packed_refs[name]
238
239     def get_refs(self):
240         ret = {}
241         if self.head():
242             ret['HEAD'] = self.head()
243         for dir in ["refs/heads", "refs/tags"]:
244             for name in os.listdir(os.path.join(self.controldir(), dir)):
245                 path = os.path.join(self.controldir(), dir, name)
246                 if os.path.isfile(path):
247                     ret["/".join([dir, name])] = self._get_ref(path)
248         ret.update(self.get_packed_refs())
249         return ret
250
251     def get_packed_refs(self):
252         path = os.path.join(self.controldir(), 'packed-refs')
253         if not os.path.exists(path):
254             return {}
255         ret = {}
256         f = open(path, 'r')
257         try:
258             for entry in read_packed_refs(f):
259                 ret[entry[1]] = entry[0]
260             return ret
261         finally:
262             f.close()
263
264     def set_ref(self, name, value):
265         file = os.path.join(self.controldir(), name)
266         dirpath = os.path.dirname(file)
267         if not os.path.exists(dirpath):
268             os.makedirs(dirpath)
269         f = open(file, 'w')
270         try:
271             f.write(value+"\n")
272         finally:
273             f.close()
274
275     def remove_ref(self, name):
276         file = os.path.join(self.controldir(), name)
277         if os.path.exists(file):
278             os.remove(file)
279
280     def tagdir(self):
281         """Tag directory."""
282         return os.path.join(self.controldir(), REFSDIR, 'tags')
283
284     def get_tags(self):
285         ret = {}
286         for root, dirs, files in os.walk(self.tagdir()):
287             for name in files:
288                 ret[name] = self._get_ref(os.path.join(root, name))
289         return ret
290
291     def heads(self):
292         ret = {}
293         for root, dirs, files in os.walk(os.path.join(self.controldir(), REFSDIR, 'heads')):
294             for name in files:
295                 ret[name] = self._get_ref(os.path.join(root, name))
296         return ret
297
298     def head(self):
299         return self.ref('HEAD')
300
301     def _get_object(self, sha, cls):
302         assert len(sha) in (20, 40)
303         ret = self.get_object(sha)
304         if ret._type != cls._type:
305             if cls is Commit:
306                 raise NotCommitError(ret)
307             elif cls is Blob:
308                 raise NotBlobError(ret)
309             elif cls is Tree:
310                 raise NotTreeError(ret)
311             else:
312                 raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
313         return ret
314
315     def get_object(self, sha):
316         return self.object_store[sha]
317
318     def get_parents(self, sha):
319         return self.commit(sha).parents
320
321     def commit(self, sha):
322         return self._get_object(sha, Commit)
323
324     def tree(self, sha):
325         return self._get_object(sha, Tree)
326
327     def tag(self, sha):
328         return self._get_object(sha, Tag)
329
330     def get_blob(self, sha):
331         return self._get_object(sha, Blob)
332
333     def revision_history(self, head):
334         """Returns a list of the commits reachable from head.
335
336         Returns a list of commit objects. the first of which will be the commit
337         of head, then following theat will be the parents.
338
339         Raises NotCommitError if any no commits are referenced, including if the
340         head parameter isn't the sha of a commit.
341
342         XXX: work out how to handle merges.
343         """
344         # We build the list backwards, as parents are more likely to be older
345         # than children
346         pending_commits = [head]
347         history = []
348         while pending_commits != []:
349             head = pending_commits.pop(0)
350             try:
351                 commit = self.commit(head)
352             except KeyError:
353                 raise MissingCommitError(head)
354             if commit in history:
355                 continue
356             i = 0
357             for known_commit in history:
358                 if known_commit.commit_time > commit.commit_time:
359                     break
360                 i += 1
361             history.insert(i, commit)
362             parents = commit.parents
363             pending_commits += parents
364         history.reverse()
365         return history
366
367     def __repr__(self):
368         return "<Repo at %r>" % self.path
369
370     @classmethod
371     def init(cls, path, mkdir=True):
372         controldir = os.path.join(path, ".git")
373         os.mkdir(controldir)
374         cls.init_bare(controldir)
375
376     @classmethod
377     def init_bare(cls, path, mkdir=True):
378         for d in [[OBJECTDIR], 
379                   [OBJECTDIR, "info"], 
380                   [OBJECTDIR, "pack"],
381                   ["branches"],
382                   [REFSDIR],
383                   ["refs", "tags"],
384                   ["refs", "heads"],
385                   ["hooks"],
386                   ["info"]]:
387             os.mkdir(os.path.join(path, *d))
388         open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
389         open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
390         open(os.path.join(path, 'info', 'excludes'), 'w').write("")
391
392     create = init_bare
393