Merge upstream
[jelmer/dulwich-libgit2.git] / dulwich / repo.py
1 # repo.py -- For dealing wih git repositories.
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20 import os
21
22 from commit import Commit
23 from errors import (
24         MissingCommitError, 
25         NotBlobError, 
26         NotCommitError, 
27         NotGitRepository,
28         NotTreeError, 
29         )
30 from object_store import ObjectStore
31 from objects import (
32         ShaFile,
33         Commit,
34         Tree,
35         Blob,
36         )
37 import tempfile
38
39 OBJECTDIR = 'objects'
40 SYMREF = 'ref: '
41
42
43 class Tag(object):
44
45     def __init__(self, name, ref):
46         self.name = name
47         self.ref = ref
48
49
50 class Repo(object):
51
52   ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
53
54   def __init__(self, root):
55     if os.path.isdir(os.path.join(root, ".git", "objects")):
56       self.bare = False
57       self._controldir = os.path.join(root, ".git")
58     elif os.path.isdir(os.path.join(root, "objects")):
59       self.bare = True
60       self._controldir = root
61     else:
62       raise NotGitRepository(root)
63     self.path = root
64     self.tags = [Tag(name, ref) for name, ref in self.get_tags().items()]
65     self._object_store = None
66
67   def controldir(self):
68     return self._controldir
69
70   def find_missing_objects(self, determine_wants, graph_walker, progress):
71     """Fetch the missing objects required for a set of revisions.
72
73     :param determine_wants: Function that takes a dictionary with heads 
74         and returns the list of heads to fetch.
75     :param graph_walker: Object that can iterate over the list of revisions 
76         to fetch and has an "ack" method that will be called to acknowledge 
77         that a revision is present.
78     :param progress: Simple progress function that will be called with 
79         updated progress strings.
80     """
81     wants = determine_wants(self.get_refs())
82     commits_to_send = set(wants)
83     sha_done = set()
84     ref = graph_walker.next()
85     while ref:
86         sha_done.add(ref)
87         if ref in self.object_store:
88             graph_walker.ack(ref)
89         ref = graph_walker.next()
90     while commits_to_send:
91         sha = commits_to_send.pop()
92         if sha in sha_done:
93             continue
94
95         c = self.commit(sha)
96         assert isinstance(c, Commit)
97         sha_done.add(sha)
98
99         commits_to_send.update([p for p in c.parents if not p in sha_done])
100
101         def parse_tree(tree, sha_done):
102             for mode, name, x in tree.entries():
103                 if not x in sha_done:
104                     try:
105                         t = self.tree(x)
106                         sha_done.add(x)
107                         parse_tree(t, sha_done)
108                     except:
109                         sha_done.add(x)
110
111         treesha = c.tree
112         if treesha not in sha_done:
113             t = self.tree(treesha)
114             sha_done.add(treesha)
115             parse_tree(t, sha_done)
116
117         progress("counting objects: %d\r" % len(sha_done))
118     return sha_done
119
120   def fetch_objects(self, determine_wants, graph_walker, progress):
121     """Fetch the missing objects required for a set of revisions.
122
123     :param determine_wants: Function that takes a dictionary with heads 
124         and returns the list of heads to fetch.
125     :param graph_walker: Object that can iterate over the list of revisions 
126         to fetch and has an "ack" method that will be called to acknowledge 
127         that a revision is present.
128     :param progress: Simple progress function that will be called with 
129         updated progress strings.
130     """
131     shas = self.find_missing_objects(determine_wants, graph_walker, progress)
132     for sha in shas:
133         yield self.get_object(sha)
134
135   def object_dir(self):
136     return os.path.join(self.controldir(), OBJECTDIR)
137
138   @property
139   def object_store(self):
140     if self._object_store is None:
141         self._object_store = ObjectStore(self.object_dir())
142     return self._object_store
143
144   def pack_dir(self):
145     return os.path.join(self.object_dir(), PACKDIR)
146
147   def _get_ref(self, file):
148     f = open(file, 'rb')
149     try:
150       contents = f.read()
151       if contents.startswith(SYMREF):
152         ref = contents[len(SYMREF):]
153         if ref[-1] == '\n':
154           ref = ref[:-1]
155         return self.ref(ref)
156       assert len(contents) == 41, 'Invalid ref in %s' % file
157       return contents[:-1]
158     finally:
159       f.close()
160
161   def ref(self, name):
162     for dir in self.ref_locs:
163       file = os.path.join(self.controldir(), dir, name)
164       if os.path.exists(file):
165         return self._get_ref(file)
166
167   def get_refs(self):
168     ret = {}
169     if self.head():
170         ret['HEAD'] = self.head()
171     for dir in ["refs/heads", "refs/tags"]:
172         for name in os.listdir(os.path.join(self.controldir(), dir)):
173           path = os.path.join(self.controldir(), dir, name)
174           if os.path.isfile(path):
175             ret["/".join([dir, name])] = self._get_ref(path)
176     return ret
177
178   def set_ref(self, name, value):
179     file = os.path.join(self.controldir(), name)
180     open(file, 'w').write(value+"\n")
181
182   def remove_ref(self, name):
183     file = os.path.join(self.controldir(), name)
184     if os.path.exists(file):
185       os.remove(file)
186       return
187
188   def get_tags(self):
189     ret = {}
190     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'tags')):
191       for name in files:
192         ret[name] = self._get_ref(os.path.join(root, name))
193     return ret
194
195   def heads(self):
196     ret = {}
197     for root, dirs, files in os.walk(os.path.join(self.controldir(), 'refs', 'heads')):
198       for name in files:
199         ret[name] = self._get_ref(os.path.join(root, name))
200     return ret
201
202   def head(self):
203     return self.ref('HEAD')
204
205   def _get_object(self, sha, cls):
206     assert len(sha) in (20, 40)
207     ret = self.get_object(sha)
208     if ret._type != cls._type:
209         if cls is Commit:
210             raise NotCommitError(ret)
211         elif cls is Blob:
212             raise NotBlobError(ret)
213         elif cls is Tree:
214             raise NotTreeError(ret)
215         else:
216             raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
217     return ret
218
219   def get_object(self, sha):
220     return self.object_store[sha]
221
222   def get_parents(self, sha):
223     return self.commit(sha).parents
224
225   def commit(self, sha):
226     return self._get_object(sha, Commit)
227
228   def tree(self, sha):
229     return self._get_object(sha, Tree)
230
231   def get_blob(self, sha):
232     return self._get_object(sha, Blob)
233
234   def revision_history(self, head):
235     """Returns a list of the commits reachable from head.
236
237     Returns a list of commit objects. the first of which will be the commit
238     of head, then following theat will be the parents.
239
240     Raises NotCommitError if any no commits are referenced, including if the
241     head parameter isn't the sha of a commit.
242
243     XXX: work out how to handle merges.
244     """
245     # We build the list backwards, as parents are more likely to be older
246     # than children
247     pending_commits = [head]
248     history = []
249     while pending_commits != []:
250       head = pending_commits.pop(0)
251       try:
252           commit = self.commit(head)
253       except KeyError:
254         raise MissingCommitError(head)
255       if commit in history:
256         continue
257       i = 0
258       for known_commit in history:
259         if known_commit.commit_time > commit.commit_time:
260           break
261         i += 1
262       history.insert(i, commit)
263       parents = commit.parents
264       pending_commits += parents
265     history.reverse()
266     return history
267
268   def __repr__(self):
269       return "<Repo at %r>" % self.path
270
271   @classmethod
272   def init_bare(cls, path, mkdir=True):
273       for d in [["objects"], 
274                 ["objects", "info"], 
275                 ["objects", "pack"],
276                 ["branches"],
277                 ["refs"],
278                 ["refs", "tags"],
279                 ["refs", "heads"],
280                 ["hooks"],
281                 ["info"]]:
282           os.mkdir(os.path.join(path, *d))
283       open(os.path.join(path, 'HEAD'), 'w').write("ref: refs/heads/master\n")
284       open(os.path.join(path, 'description'), 'w').write("Unnamed repository")
285       open(os.path.join(path, 'info', 'excludes'), 'w').write("")
286
287   create = init_bare
288
289
290