Move some of the finding missing objects code to object_store.
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License or (at your option) any later version of 
9 # the License.
10
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19 # MA  02110-1301, USA.
20
21 import os
22 import stat
23
24 from dulwich.errors import (
25     MissingCommitError, 
26     NotBlobError, 
27     NotCommitError, 
28     NotGitRepository,
29     NotTreeError, 
30     )
31 from dulwich.object_store import (
32     ObjectStore,
33     )
34 from dulwich.objects import (
35     Blob,
36     Commit,
37     ShaFile,
38     Tag,
39     Tree,
40     )
41
42 OBJECTDIR = 'objects'
43 SYMREF = 'ref: '
44 REFSDIR = 'refs'
45 INDEX_FILENAME = "index"
46
47 class Tags(object):
48     """Tags container."""
49
50     def __init__(self, tagdir, tags):
51         self.tagdir = tagdir
52         self.tags = tags
53
54     def __getitem__(self, name):
55         return self.tags[name]
56     
57     def __setitem__(self, name, ref):
58         self.tags[name] = ref
59         f = open(os.path.join(self.tagdir, name), 'wb')
60         try:
61             f.write("%s\n" % ref)
62         finally:
63             f.close()
64
65     def __len__(self):
66         return len(self.tags)
67
68     def iteritems(self):
69         for k in self.tags:
70             yield k, self[k]
71
72
73 def read_packed_refs(f):
74     """Read a packed refs file.
75
76     Yields tuples with ref names and SHA1s.
77
78     :param f: file-like object to read from
79     """
80     l = f.readline()
81     for l in f.readlines():
82         if l[0] == "#":
83             # Comment
84             continue
85         if l[0] == "^":
86             # FIXME: Return somehow
87             continue
88         yield tuple(l.rstrip("\n").split(" ", 2))
89
90
91 class Repo(object):
92     """A local git repository."""
93
94     ref_locs = ['', REFSDIR, 'refs/tags', 'refs/heads', 'refs/remotes']
95
96     def __init__(self, root):
97         if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
98             self.bare = False
99             self._controldir = os.path.join(root, ".git")
100         elif os.path.isdir(os.path.join(root, OBJECTDIR)):
101             self.bare = True
102             self._controldir = root
103         else:
104             raise NotGitRepository(root)
105         self.path = root
106         self.tags = Tags(self.tagdir(), self.get_tags())
107         self._object_store = None
108
109     def controldir(self):
110         """Return the path of the control directory."""
111         return self._controldir
112
113     def index_path(self):
114         return os.path.join(self.controldir(), INDEX_FILENAME)
115
116     def open_index(self):
117         """Open the index for this repository."""
118         from dulwich.index import Index
119         return Index(self.index_path())
120
121     def has_index(self):
122         """Check if an index is present."""
123         return os.path.exists(self.index_path())
124
125     def find_missing_objects(self, determine_wants, graph_walker, progress):
126         """Find the missing objects required for a set of revisions.
127
128         :param determine_wants: Function that takes a dictionary with heads 
129             and returns the list of heads to fetch.
130         :param graph_walker: Object that can iterate over the list of revisions 
131             to fetch and has an "ack" method that will be called to acknowledge 
132             that a revision is present.
133         :param progress: Simple progress function that will be called with 
134             updated progress strings.
135         :return: Iterator over (sha, path) pairs.
136         """
137         wants = determine_wants(self.get_refs())
138         return self.object_store.find_missing_objects(wants, 
139                 graph_walker, progress)
140
141     def fetch_objects(self, determine_wants, graph_walker, progress):
142         """Fetch the missing objects required for a set of revisions.
143
144         :param determine_wants: Function that takes a dictionary with heads 
145             and returns the list of heads to fetch.
146         :param graph_walker: Object that can iterate over the list of revisions 
147             to fetch and has an "ack" method that will be called to acknowledge 
148             that a revision is present.
149         :param progress: Simple progress function that will be called with 
150             updated progress strings.
151         :return: tuple with number of objects, iterator over objects
152         """
153         return self.object_store.iter_shas(
154             self.find_missing_objects(determine_wants, graph_walker, progress))
155
156     def object_dir(self):
157         return os.path.join(self.controldir(), OBJECTDIR)
158
159     @property
160     def object_store(self):
161         if self._object_store is None:
162             self._object_store = ObjectStore(self.object_dir())
163         return self._object_store
164
165     def pack_dir(self):
166         return os.path.join(self.object_dir(), PACKDIR)
167
168     def _get_ref(self, file):
169         f = open(file, 'rb')
170         try:
171             contents = f.read()
172             if contents.startswith(SYMREF):
173                 ref = contents[len(SYMREF):]
174                 if ref[-1] == '\n':
175                     ref = ref[:-1]
176                 return self.ref(ref)
177             assert len(contents) == 41, 'Invalid ref in %s' % file
178             return contents[:-1]
179         finally:
180             f.close()
181
182     def ref(self, name):
183         """Return the SHA1 a ref is pointing to."""
184         for dir in self.ref_locs:
185             file = os.path.join(self.controldir(), dir, name)
186             if os.path.exists(file):
187                 return self._get_ref(file)
188         packed_refs = self.get_packed_refs()
189         if name in packed_refs:
190             return packed_refs[name]
191
192     def get_refs(self):
193         ret = {}
194         if self.head():
195             ret['HEAD'] = self.head()
196         for dir in ["refs/heads", "refs/tags"]:
197             for name in os.listdir(os.path.join(self.controldir(), dir)):
198                 path = os.path.join(self.controldir(), dir, name)
199                 if os.path.isfile(path):
200                     ret["/".join([dir, name])] = self._get_ref(path)
201         ret.update(self.get_packed_refs())
202         return ret
203
204     def get_packed_refs(self):
205         path = os.path.join(self.controldir(), 'packed-refs')
206         if not os.path.exists(path):
207             return {}
208         ret = {}
209         f = open(path, 'r')
210         try:
211             for entry in read_packed_refs(f):
212                 ret[entry[1]] = entry[0]
213             return ret
214         finally:
215             f.close()
216
217     def set_ref(self, name, value):
218         file = os.path.join(self.controldir(), name)
219         dirpath = os.path.dirname(file)
220         if not os.path.exists(dirpath):
221             os.makedirs(dirpath)
222         f = open(file, 'w')
223         try:
224             f.write(value+"\n")
225         finally:
226             f.close()
227
228     def remove_ref(self, name):
229         file = os.path.join(self.controldir(), name)
230         if os.path.exists(file):
231             os.remove(file)
232
233     def tagdir(self):
234         """Tag directory."""
235         return os.path.join(self.controldir(), REFSDIR, 'tags')
236
237     def get_tags(self):
238         ret = {}
239         for root, dirs, files in os.walk(self.tagdir()):
240             for name in files:
241                 ret[name] = self._get_ref(os.path.join(root, name))
242         return ret
243
244     def heads(self):
245         ret = {}
246         for root, dirs, files in os.walk(os.path.join(self.controldir(), REFSDIR, 'heads')):
247             for name in files:
248                 ret[name] = self._get_ref(os.path.join(root, name))
249         return ret
250
251     def head(self):
252         return self.ref('HEAD')
253
254     def _get_object(self, sha, cls):
255         assert len(sha) in (20, 40)
256         ret = self.get_object(sha)
257         if ret._type != cls._type:
258             if cls is Commit:
259                 raise NotCommitError(ret)
260             elif cls is Blob:
261                 raise NotBlobError(ret)
262             elif cls is Tree:
263                 raise NotTreeError(ret)
264             else:
265                 raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
266         return ret
267
268     def get_object(self, sha):
269         return self.object_store[sha]
270
271     def get_parents(self, sha):
272         return self.commit(sha).parents
273
274     def commit(self, sha):
275         return self._get_object(sha, Commit)
276
277     def tree(self, sha):
278         return self._get_object(sha, Tree)
279
280     def tag(self, sha):
281         return self._get_object(sha, Tag)
282
283     def get_blob(self, sha):
284         return self._get_object(sha, Blob)
285
286     def revision_history(self, head):
287         """Returns a list of the commits reachable from head.
288
289         Returns a list of commit objects. the first of which will be the commit
290         of head, then following theat will be the parents.
291
292         Raises NotCommitError if any no commits are referenced, including if the
293         head parameter isn't the sha of a commit.
294
295         XXX: work out how to handle merges.
296         """
297         # We build the list backwards, as parents are more likely to be older
298         # than children
299         pending_commits = [head]
300         history = []
301         while pending_commits != []:
302             head = pending_commits.pop(0)
303             try:
304                 commit = self.commit(head)
305             except KeyError:
306                 raise MissingCommitError(head)
307             if commit in history:
308                 continue
309             i = 0
310             for known_commit in history:
311                 if known_commit.commit_time > commit.commit_time:
312                     break
313                 i += 1
314             history.insert(i, commit)
315             parents = commit.parents
316             pending_commits += parents
317         history.reverse()
318         return history
319
320     def __repr__(self):
321         return "<Repo at %r>" % self.path
322
323     @classmethod
324     def init(cls, path, mkdir=True):
325         controldir = os.path.join(path, ".git")
326         os.mkdir(controldir)
327         cls.init_bare(controldir)
328
329     @classmethod
330     def init_bare(cls, path, mkdir=True):
331         for d in [[OBJECTDIR], 
332                   [OBJECTDIR, "info"], 
333                   [OBJECTDIR, "pack"],
334                   ["branches"],
335                   [REFSDIR],
336                   ["refs", "tags"],
337                   ["refs", "heads"],
338                   ["hooks"],
339                   ["info"]]:
340             os.mkdir(os.path.join(path, *d))
341         open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
342         open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
343         open(os.path.join(path, 'info', 'excludes'), 'w').write("")
344
345     create = init_bare
346