Fix C implementation of parse_tree to return a dictionary.
[jelmer/dulwich-libgit2.git] / dulwich / objects.py
1 # objects.py -- Access to base git objects
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License or (at your option) a later version of the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20
21 """Access to base git objects."""
22
23
24 import mmap
25 import os
26 import sha
27 import zlib
28
29 from dulwich.errors import (
30     NotBlobError,
31     NotCommitError,
32     NotTreeError,
33     )
34
35 BLOB_ID = "blob"
36 TAG_ID = "tag"
37 TREE_ID = "tree"
38 COMMIT_ID = "commit"
39 PARENT_ID = "parent"
40 AUTHOR_ID = "author"
41 COMMITTER_ID = "committer"
42 OBJECT_ID = "object"
43 TYPE_ID = "type"
44 TAGGER_ID = "tagger"
45
46 def _decompress(string):
47     dcomp = zlib.decompressobj()
48     dcomped = dcomp.decompress(string)
49     dcomped += dcomp.flush()
50     return dcomped
51
52
53 def sha_to_hex(sha):
54     """Takes a string and returns the hex of the sha within"""
55     hexsha = "".join(["%02x" % ord(c) for c in sha])
56     assert len(hexsha) == 40, "Incorrect length of sha1 string: %d" % hexsha
57     return hexsha
58
59
60 def hex_to_sha(hex):
61     """Takes a hex sha and returns a binary sha"""
62     assert len(hex) == 40, "Incorrent length of hexsha: %s" % hex
63     return ''.join([chr(int(hex[i:i+2], 16)) for i in xrange(0, len(hex), 2)])
64
65
66 class ShaFile(object):
67     """A git SHA file."""
68   
69     @classmethod
70     def _parse_legacy_object(cls, map):
71         """Parse a legacy object, creating it and setting object._text"""
72         text = _decompress(map)
73         object = None
74         for posstype in type_map.keys():
75             if text.startswith(posstype):
76                 object = type_map[posstype]()
77                 text = text[len(posstype):]
78                 break
79         assert object is not None, "%s is not a known object type" % text[:9]
80         assert text[0] == ' ', "%s is not a space" % text[0]
81         text = text[1:]
82         size = 0
83         i = 0
84         while text[0] >= '0' and text[0] <= '9':
85             if i > 0 and size == 0:
86                 assert False, "Size is not in canonical format"
87             size = (size * 10) + int(text[0])
88             text = text[1:]
89             i += 1
90         object._size = size
91         assert text[0] == "\0", "Size not followed by null"
92         text = text[1:]
93         object._text = text
94         return object
95
96     def as_legacy_object(self):
97         return zlib.compress("%s %d\0%s" % (self._type, len(self._text), self._text))
98   
99     def as_raw_string(self):
100         return self._num_type, self._text
101   
102     @classmethod
103     def _parse_object(cls, map):
104         """Parse a new style object , creating it and setting object._text"""
105         used = 0
106         byte = ord(map[used])
107         used += 1
108         num_type = (byte >> 4) & 7
109         try:
110             object = num_type_map[num_type]()
111         except KeyError:
112             raise AssertionError("Not a known type: %d" % num_type)
113         while (byte & 0x80) != 0:
114             byte = ord(map[used])
115             used += 1
116         raw = map[used:]
117         object._text = _decompress(raw)
118         return object
119   
120     @classmethod
121     def _parse_file(cls, map):
122         word = (ord(map[0]) << 8) + ord(map[1])
123         if ord(map[0]) == 0x78 and (word % 31) == 0:
124             return cls._parse_legacy_object(map)
125         else:
126             return cls._parse_object(map)
127   
128     def __init__(self):
129         """Don't call this directly"""
130   
131     def _parse_text(self):
132         """For subclasses to do initialisation time parsing"""
133   
134     @classmethod
135     def from_file(cls, filename):
136         """Get the contents of a SHA file on disk"""
137         size = os.path.getsize(filename)
138         f = open(filename, 'rb')
139         try:
140             map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
141             shafile = cls._parse_file(map)
142             shafile._parse_text()
143             return shafile
144         finally:
145             f.close()
146   
147     @classmethod
148     def from_raw_string(cls, type, string):
149         """Creates an object of the indicated type from the raw string given.
150     
151         Type is the numeric type of an object. String is the raw uncompressed
152         contents.
153         """
154         real_class = num_type_map[type]
155         obj = real_class()
156         obj._num_type = type
157         obj._text = string
158         obj._parse_text()
159         return obj
160   
161     def _header(self):
162         return "%s %lu\0" % (self._type, len(self._text))
163   
164     def sha(self):
165         """The SHA1 object that is the name of this object."""
166         ressha = sha.new()
167         ressha.update(self._header())
168         ressha.update(self._text)
169         return ressha
170   
171     @property
172     def id(self):
173         return self.sha().hexdigest()
174   
175     @property
176     def type(self):
177         return self._num_type
178   
179     def __repr__(self):
180         return "<%s %s>" % (self.__class__.__name__, self.id)
181   
182     def __eq__(self, other):
183         """Return true id the sha of the two objects match.
184   
185         The __le__ etc methods aren't overriden as they make no sense,
186         certainly at this level.
187         """
188         return self.sha().digest() == other.sha().digest()
189
190
191 class Blob(ShaFile):
192     """A Git Blob object."""
193
194     _type = BLOB_ID
195     _num_type = 3
196
197     @property
198     def data(self):
199         """The text contained within the blob object."""
200         return self._text
201
202     @classmethod
203     def from_file(cls, filename):
204         blob = ShaFile.from_file(filename)
205         if blob._type != cls._type:
206             raise NotBlobError(filename)
207         return blob
208
209     @classmethod
210     def from_string(cls, string):
211         """Create a blob from a string."""
212         shafile = cls()
213         shafile._text = string
214         return shafile
215
216
217 class Tag(ShaFile):
218     """A Git Tag object."""
219
220     _type = TAG_ID
221     _num_type = 4
222
223     @classmethod
224     def from_file(cls, filename):
225         blob = ShaFile.from_file(filename)
226         if blob._type != cls._type:
227             raise NotBlobError(filename)
228         return blob
229
230     @classmethod
231     def from_string(cls, string):
232         """Create a blob from a string."""
233         shafile = cls()
234         shafile._text = string
235         return shafile
236
237     def _parse_text(self):
238         """Grab the metadata attached to the tag"""
239         text = self._text
240         count = 0
241         assert text.startswith(OBJECT_ID), "Invalid tag object, " \
242             "must start with %s" % OBJECT_ID
243         count += len(OBJECT_ID)
244         assert text[count] == ' ', "Invalid tag object, " \
245             "%s must be followed by space not %s" % (OBJECT_ID, text[count])
246         count += 1
247         self._object_sha = text[count:count+40]
248         count += 40
249         assert text[count] == '\n', "Invalid tag object, " \
250             "%s sha must be followed by newline" % OBJECT_ID
251         count += 1
252         assert text[count:].startswith(TYPE_ID), "Invalid tag object, " \
253             "%s sha must be followed by %s" % (OBJECT_ID, TYPE_ID)
254         count += len(TYPE_ID)
255         assert text[count] == ' ', "Invalid tag object, " \
256             "%s must be followed by space not %s" % (TAG_ID, text[count])
257         count += 1
258         self._object_type = ""
259         while text[count] != '\n':
260             self._object_type += text[count]
261             count += 1
262         count += 1
263         assert self._object_type in (COMMIT_ID, BLOB_ID, TREE_ID, TAG_ID), "Invalid tag object, " \
264             "unexpected object type %s" % self._object_type
265         self._object_type = type_map[self._object_type]
266
267         assert text[count:].startswith(TAG_ID), "Invalid tag object, " \
268             "object type must be followed by %s" % (TAG_ID)
269         count += len(TAG_ID)
270         assert text[count] == ' ', "Invalid tag object, " \
271             "%s must be followed by space not %s" % (TAG_ID, text[count])
272         count += 1
273         self._name = ""
274         while text[count] != '\n':
275             self._name += text[count]
276             count += 1
277         count += 1
278
279         assert text[count:].startswith(TAGGER_ID), "Invalid tag object, " \
280             "%s must be followed by %s" % (TAG_ID, TAGGER_ID)
281         count += len(TAGGER_ID)
282         assert text[count] == ' ', "Invalid tag object, " \
283             "%s must be followed by space not %s" % (TAGGER_ID, text[count])
284         count += 1
285         self._tagger = ""
286         while text[count] != '>':
287             assert text[count] != '\n', "Malformed tagger information"
288             self._tagger += text[count]
289             count += 1
290         self._tagger += text[count]
291         count += 1
292         assert text[count] == ' ', "Invalid tag object, " \
293             "tagger information must be followed by space not %s" % text[count]
294         count += 1
295         self._tag_time = int(text[count:count+10])
296         while text[count] != '\n':
297             count += 1
298         count += 1
299         assert text[count] == '\n', "There must be a new line after the headers"
300         count += 1
301         self._message = text[count:]
302
303     @property
304     def object(self):
305         """Returns the object pointed by this tag, represented as a tuple(type, sha)"""
306         return (self._object_type, self._object_sha)
307
308     @property
309     def name(self):
310         """Returns the name of this tag"""
311         return self._name
312
313     @property
314     def tagger(self):
315         """Returns the name of the person who created this tag"""
316         return self._tagger
317
318     @property
319     def tag_time(self):
320         """Returns the creation timestamp of the tag.
321
322         Returns it as the number of seconds since the epoch"""
323         return self._tag_time
324
325     @property
326     def message(self):
327         """Returns the message attached to this tag"""
328         return self._message
329
330
331 def parse_tree(text):
332     ret = []
333     count = 0
334     while count < len(text):
335         mode = 0
336         chr = text[count]
337         while chr != ' ':
338             assert chr >= '0' and chr <= '7', "%s is not a valid mode char" % chr
339             mode = (mode << 3) + (ord(chr) - ord('0'))
340             count += 1
341             chr = text[count]
342         count += 1
343         chr = text[count]
344         name = ''
345         while chr != '\0':
346             name += chr
347             count += 1
348             chr = text[count]
349         count += 1
350         chr = text[count]
351         sha = text[count:count+20]
352         hexsha = sha_to_hex(sha)
353         ret.append((mode, name, hexsha))
354         count = count + 20
355     return ret
356
357
358 class Tree(ShaFile):
359     """A Git tree object"""
360
361     _type = TREE_ID
362     _num_type = 2
363
364     def __init__(self):
365         self._entries = {}
366
367     @classmethod
368     def from_file(cls, filename):
369         tree = ShaFile.from_file(filename)
370         if tree._type != cls._type:
371             raise NotTreeError(filename)
372         return tree
373
374     def __getitem__(self, name):
375         return self._entries[name]
376
377     def __setitem__(self, name, value):
378         assert isinstance(value, tuple)
379         assert len(value) == 2
380         self._entries[name] = value
381
382     def __delitem__(self, name):
383         del self._entries[name]
384
385     def add(self, mode, name, hexsha):
386         self._entries[name] = mode, hexsha
387
388     def entries(self):
389         """Return a list of tuples describing the tree entries"""
390         # The order of this is different from iteritems() for historical reasons
391         return [(mode, name, hexsha) for (name, mode, hexsha) in self.iteritems()]
392
393     def iteritems(self):
394         for name in sorted(self._entries.keys()):
395             yield name, self._entries[name][0], self._entries[name][1]
396
397     def _parse_text(self):
398         """Grab the entries in the tree"""
399         self._entries = parse_tree(self._text)
400
401     def serialize(self):
402         self._text = ""
403         for name, mode, hexsha in self.iteritems():
404             self._text += "%04o %s\0%s" % (mode, name, hex_to_sha(hexsha))
405
406
407 class Commit(ShaFile):
408     """A git commit object"""
409
410     _type = COMMIT_ID
411     _num_type = 1
412
413     def __init__(self):
414         self._parents = []
415
416     @classmethod
417     def from_file(cls, filename):
418         commit = ShaFile.from_file(filename)
419         if commit._type != cls._type:
420             raise NotCommitError(filename)
421         return commit
422
423     def _parse_text(self):
424         text = self._text
425         count = 0
426         assert text.startswith(TREE_ID), "Invalid commit object, " \
427              "must start with %s" % TREE_ID
428         count += len(TREE_ID)
429         assert text[count] == ' ', "Invalid commit object, " \
430              "%s must be followed by space not %s" % (TREE_ID, text[count])
431         count += 1
432         self._tree = text[count:count+40]
433         count = count + 40
434         assert text[count] == "\n", "Invalid commit object, " \
435              "tree sha must be followed by newline"
436         count += 1
437         self._parents = []
438         while text[count:].startswith(PARENT_ID):
439             count += len(PARENT_ID)
440             assert text[count] == ' ', "Invalid commit object, " \
441                  "%s must be followed by space not %s" % (PARENT_ID, text[count])
442             count += 1
443             self._parents.append(text[count:count+40])
444             count += 40
445             assert text[count] == "\n", "Invalid commit object, " \
446                  "parent sha must be followed by newline"
447             count += 1
448         self._author = None
449         if text[count:].startswith(AUTHOR_ID):
450             count += len(AUTHOR_ID)
451             assert text[count] == ' ', "Invalid commit object, " \
452                  "%s must be followed by space not %s" % (AUTHOR_ID, text[count])
453             count += 1
454             self._author = ''
455             while text[count] != '>':
456                 assert text[count] != '\n', "Malformed author information"
457                 self._author += text[count]
458                 count += 1
459             self._author += text[count]
460             count += 1
461             assert text[count] == ' ', "Invalid commit object, " \
462                  "author information must be followed by space not %s" % text[count]
463             count += 1
464             self._author_time = int(text[count:count+10])
465             while text[count] != ' ':
466                 assert text[count] != '\n', "Malformed author information"
467                 count += 1
468             self._author_timezone = int(text[count:count+6])
469             count += 1
470             while text[count] != '\n':
471                 count += 1
472             count += 1
473         self._committer = None
474         if text[count:].startswith(COMMITTER_ID):
475             count += len(COMMITTER_ID)
476             assert text[count] == ' ', "Invalid commit object, " \
477                  "%s must be followed by space not %s" % (COMMITTER_ID, text[count])
478             count += 1
479             self._committer = ''
480             while text[count] != '>':
481                 assert text[count] != '\n', "Malformed committer information"
482                 self._committer += text[count]
483                 count += 1
484             self._committer += text[count]
485             count += 1
486             assert text[count] == ' ', "Invalid commit object, " \
487                  "commiter information must be followed by space not %s" % text[count]
488             count += 1
489             self._commit_time = int(text[count:count+10])
490             while text[count] != ' ':
491                 assert text[count] != '\n', "Malformed committer information"
492                 count += 1
493             self._commit_timezone = int(text[count:count+6])
494             count += 1
495             while text[count] != '\n':
496                 count += 1
497             count += 1
498         assert text[count] == '\n', "There must be a new line after the headers"
499         count += 1
500         # XXX: There can be an encoding field.
501         self._message = text[count:]
502
503     def serialize(self):
504         self._text = ""
505         self._text += "%s %s\n" % (TREE_ID, self._tree)
506         for p in self._parents:
507             self._text += "%s %s\n" % (PARENT_ID, p)
508         self._text += "%s %s %s %+05d\n" % (AUTHOR_ID, self._author, str(self._author_time), self._author_timezone)
509         self._text += "%s %s %s %+05d\n" % (COMMITTER_ID, self._committer, str(self._commit_time), self._commit_timezone)
510         self._text += "\n" # There must be a new line after the headers
511         self._text += self._message
512
513     @property
514     def tree(self):
515         """Returns the tree that is the state of this commit"""
516         return self._tree
517
518     @property
519     def parents(self):
520         """Return a list of parents of this commit."""
521         return self._parents
522
523     @property
524     def author(self):
525         """Returns the name of the author of the commit"""
526         return self._author
527
528     @property
529     def committer(self):
530         """Returns the name of the committer of the commit"""
531         return self._committer
532
533     @property
534     def message(self):
535         """Returns the commit message"""
536         return self._message
537
538     @property
539     def commit_time(self):
540         """Returns the timestamp of the commit.
541         
542         Returns it as the number of seconds since the epoch.
543         """
544         return self._commit_time
545
546     @property
547     def commit_timezone(self):
548         """Returns the zone the commit time is in
549         """
550         return self._commit_timezone
551
552     @property
553     def author_time(self):
554         """Returns the timestamp the commit was written.
555         
556         Returns it as the number of seconds since the epoch.
557         """
558         return self._author_time
559
560     @property
561     def author_timezone(self):
562         """Returns the zone the author time is in
563         """
564         return self._author_timezone
565
566
567 type_map = {
568     BLOB_ID : Blob,
569     TREE_ID : Tree,
570     COMMIT_ID : Commit,
571     TAG_ID: Tag,
572 }
573
574 num_type_map = {
575     0: None,
576     1: Commit,
577     2: Tree,
578     3: Blob,
579     4: Tag,
580     # 5 Is reserved for further expansion
581 }
582
583 try:
584     # Try to import C versions
585     from dulwich._objects import hex_to_sha, sha_to_hex, parse_tree
586 except ImportError:
587     pass
588