Use separate function for tree parsing, allow C extension for tree parsing.
[jelmer/dulwich-libgit2.git] / dulwich / objects.py
1 # objects.py -- Access to base git objects
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
4
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; version 2
8 # of the License or (at your option) a later version of the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20
21 """Access to base git objects."""
22
23
24 import mmap
25 import os
26 import sha
27 import zlib
28
29 from dulwich.errors import (
30     NotBlobError,
31     NotCommitError,
32     NotTreeError,
33     )
34
35 BLOB_ID = "blob"
36 TAG_ID = "tag"
37 TREE_ID = "tree"
38 COMMIT_ID = "commit"
39 PARENT_ID = "parent"
40 AUTHOR_ID = "author"
41 COMMITTER_ID = "committer"
42 OBJECT_ID = "object"
43 TYPE_ID = "type"
44 TAGGER_ID = "tagger"
45
46 def _decompress(string):
47     dcomp = zlib.decompressobj()
48     dcomped = dcomp.decompress(string)
49     dcomped += dcomp.flush()
50     return dcomped
51
52
53 def sha_to_hex(sha):
54     """Takes a string and returns the hex of the sha within"""
55     hexsha = "".join(["%02x" % ord(c) for c in sha])
56     assert len(hexsha) == 40, "Incorrect length of sha1 string: %d" % hexsha
57     return hexsha
58
59
60 def hex_to_sha(hex):
61     """Takes a hex sha and returns a binary sha"""
62     assert len(hex) == 40, "Incorrent length of hexsha: %s" % hex
63     return ''.join([chr(int(hex[i:i+2], 16)) for i in xrange(0, len(hex), 2)])
64
65
66 class ShaFile(object):
67     """A git SHA file."""
68   
69     @classmethod
70     def _parse_legacy_object(cls, map):
71         """Parse a legacy object, creating it and setting object._text"""
72         text = _decompress(map)
73         object = None
74         for posstype in type_map.keys():
75             if text.startswith(posstype):
76                 object = type_map[posstype]()
77                 text = text[len(posstype):]
78                 break
79         assert object is not None, "%s is not a known object type" % text[:9]
80         assert text[0] == ' ', "%s is not a space" % text[0]
81         text = text[1:]
82         size = 0
83         i = 0
84         while text[0] >= '0' and text[0] <= '9':
85             if i > 0 and size == 0:
86                 assert False, "Size is not in canonical format"
87             size = (size * 10) + int(text[0])
88             text = text[1:]
89             i += 1
90         object._size = size
91         assert text[0] == "\0", "Size not followed by null"
92         text = text[1:]
93         object._text = text
94         return object
95
96     def as_legacy_object(self):
97         return zlib.compress("%s %d\0%s" % (self._type, len(self._text), self._text))
98   
99     def as_raw_string(self):
100         return self._num_type, self._text
101   
102     @classmethod
103     def _parse_object(cls, map):
104         """Parse a new style object , creating it and setting object._text"""
105         used = 0
106         byte = ord(map[used])
107         used += 1
108         num_type = (byte >> 4) & 7
109         try:
110             object = num_type_map[num_type]()
111         except KeyError:
112             raise AssertionError("Not a known type: %d" % num_type)
113         while (byte & 0x80) != 0:
114             byte = ord(map[used])
115             used += 1
116         raw = map[used:]
117         object._text = _decompress(raw)
118         return object
119   
120     @classmethod
121     def _parse_file(cls, map):
122         word = (ord(map[0]) << 8) + ord(map[1])
123         if ord(map[0]) == 0x78 and (word % 31) == 0:
124             return cls._parse_legacy_object(map)
125         else:
126             return cls._parse_object(map)
127   
128     def __init__(self):
129         """Don't call this directly"""
130   
131     def _parse_text(self):
132         """For subclasses to do initialisation time parsing"""
133   
134     @classmethod
135     def from_file(cls, filename):
136         """Get the contents of a SHA file on disk"""
137         size = os.path.getsize(filename)
138         f = open(filename, 'rb')
139         try:
140             map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
141             shafile = cls._parse_file(map)
142             shafile._parse_text()
143             return shafile
144         finally:
145             f.close()
146   
147     @classmethod
148     def from_raw_string(cls, type, string):
149         """Creates an object of the indicated type from the raw string given.
150     
151         Type is the numeric type of an object. String is the raw uncompressed
152         contents.
153         """
154         real_class = num_type_map[type]
155         obj = real_class()
156         obj._num_type = type
157         obj._text = string
158         obj._parse_text()
159         return obj
160   
161     def _header(self):
162         return "%s %lu\0" % (self._type, len(self._text))
163   
164     def sha(self):
165         """The SHA1 object that is the name of this object."""
166         ressha = sha.new()
167         ressha.update(self._header())
168         ressha.update(self._text)
169         return ressha
170   
171     @property
172     def id(self):
173         return self.sha().hexdigest()
174   
175     @property
176     def type(self):
177         return self._num_type
178   
179     def __repr__(self):
180         return "<%s %s>" % (self.__class__.__name__, self.id)
181   
182     def __eq__(self, other):
183         """Return true id the sha of the two objects match.
184   
185         The __le__ etc methods aren't overriden as they make no sense,
186         certainly at this level.
187         """
188         return self.sha().digest() == other.sha().digest()
189
190
191 class Blob(ShaFile):
192     """A Git Blob object."""
193
194     _type = BLOB_ID
195     _num_type = 3
196
197     @property
198     def data(self):
199         """The text contained within the blob object."""
200         return self._text
201
202     @classmethod
203     def from_file(cls, filename):
204         blob = ShaFile.from_file(filename)
205         if blob._type != cls._type:
206             raise NotBlobError(filename)
207         return blob
208
209     @classmethod
210     def from_string(cls, string):
211         """Create a blob from a string."""
212         shafile = cls()
213         shafile._text = string
214         return shafile
215
216
217 class Tag(ShaFile):
218     """A Git Tag object."""
219
220     _type = TAG_ID
221     _num_type = 4
222
223     @classmethod
224     def from_file(cls, filename):
225         blob = ShaFile.from_file(filename)
226         if blob._type != cls._type:
227             raise NotBlobError(filename)
228         return blob
229
230     @classmethod
231     def from_string(cls, string):
232         """Create a blob from a string."""
233         shafile = cls()
234         shafile._text = string
235         return shafile
236
237     def _parse_text(self):
238         """Grab the metadata attached to the tag"""
239         text = self._text
240         count = 0
241         assert text.startswith(OBJECT_ID), "Invalid tag object, " \
242             "must start with %s" % OBJECT_ID
243         count += len(OBJECT_ID)
244         assert text[count] == ' ', "Invalid tag object, " \
245             "%s must be followed by space not %s" % (OBJECT_ID, text[count])
246         count += 1
247         self._object_sha = text[count:count+40]
248         count += 40
249         assert text[count] == '\n', "Invalid tag object, " \
250             "%s sha must be followed by newline" % OBJECT_ID
251         count += 1
252         assert text[count:].startswith(TYPE_ID), "Invalid tag object, " \
253             "%s sha must be followed by %s" % (OBJECT_ID, TYPE_ID)
254         count += len(TYPE_ID)
255         assert text[count] == ' ', "Invalid tag object, " \
256             "%s must be followed by space not %s" % (TAG_ID, text[count])
257         count += 1
258         self._object_type = ""
259         while text[count] != '\n':
260             self._object_type += text[count]
261             count += 1
262         count += 1
263         assert self._object_type in (COMMIT_ID, BLOB_ID, TREE_ID, TAG_ID), "Invalid tag object, " \
264             "unexpected object type %s" % self._object_type
265         self._object_type = type_map[self._object_type]
266
267         assert text[count:].startswith(TAG_ID), "Invalid tag object, " \
268             "object type must be followed by %s" % (TAG_ID)
269         count += len(TAG_ID)
270         assert text[count] == ' ', "Invalid tag object, " \
271             "%s must be followed by space not %s" % (TAG_ID, text[count])
272         count += 1
273         self._name = ""
274         while text[count] != '\n':
275             self._name += text[count]
276             count += 1
277         count += 1
278
279         assert text[count:].startswith(TAGGER_ID), "Invalid tag object, " \
280             "%s must be followed by %s" % (TAG_ID, TAGGER_ID)
281         count += len(TAGGER_ID)
282         assert text[count] == ' ', "Invalid tag object, " \
283             "%s must be followed by space not %s" % (TAGGER_ID, text[count])
284         count += 1
285         self._tagger = ""
286         while text[count] != '>':
287             assert text[count] != '\n', "Malformed tagger information"
288             self._tagger += text[count]
289             count += 1
290         self._tagger += text[count]
291         count += 1
292         assert text[count] == ' ', "Invalid tag object, " \
293             "tagger information must be followed by space not %s" % text[count]
294         count += 1
295         self._tag_time = int(text[count:count+10])
296         while text[count] != '\n':
297             count += 1
298         count += 1
299         assert text[count] == '\n', "There must be a new line after the headers"
300         count += 1
301         self._message = text[count:]
302
303     @property
304     def object(self):
305         """Returns the object pointed by this tag, represented as a tuple(type, sha)"""
306         return (self._object_type, self._object_sha)
307
308     @property
309     def name(self):
310         """Returns the name of this tag"""
311         return self._name
312
313     @property
314     def tagger(self):
315         """Returns the name of the person who created this tag"""
316         return self._tagger
317
318     @property
319     def tag_time(self):
320         """Returns the creation timestamp of the tag.
321
322         Returns it as the number of seconds since the epoch"""
323         return self._tag_time
324
325     @property
326     def message(self):
327         """Returns the message attached to this tag"""
328         return self._message
329
330
331 def parse_tree(text):
332     ret = []
333     count = 0
334     while count < len(text):
335         mode = 0
336         chr = text[count]
337         while chr != ' ':
338             assert chr >= '0' and chr <= '7', "%s is not a valid mode char" % chr
339             mode = (mode << 3) + (ord(chr) - ord('0'))
340             count += 1
341             chr = text[count]
342         count += 1
343         chr = text[count]
344         name = ''
345         while chr != '\0':
346             name += chr
347             count += 1
348             chr = text[count]
349         count += 1
350         chr = text[count]
351         sha = text[count:count+20]
352         hexsha = sha_to_hex(sha)
353         ret.append((mode, name, hexsha))
354         count = count + 20
355     return ret
356
357
358 class Tree(ShaFile):
359     """A Git tree object"""
360
361     _type = TREE_ID
362     _num_type = 2
363
364     def __init__(self):
365         self._entries = {}
366
367     @classmethod
368     def from_file(cls, filename):
369         tree = ShaFile.from_file(filename)
370         if tree._type != cls._type:
371             raise NotTreeError(filename)
372         return tree
373
374     def __getitem__(self, name):
375         return self._entries[name]
376
377     def __setitem__(self, name, value):
378         assert isinstance(value, tuple)
379         assert len(value) == 2
380         self._entries[name] = value
381
382     def __delitem__(self, name):
383         del self._entries[name]
384
385     def add(self, mode, name, hexsha):
386         self._entries[name] = mode, hexsha
387
388     def entries(self):
389         """Return a list of tuples describing the tree entries"""
390         return [(mode, name, hexsha) for (name, (mode, hexsha)) in self._entries.iteritems()]
391
392     def iteritems(self):
393         for name in sorted(self._entries.keys()):
394             yield name, self._entries[name][0], self._entries[name][1]
395
396     def _parse_text(self):
397         """Grab the entries in the tree"""
398         self._entries = parse_tree(self._text)
399
400     def serialize(self):
401         self._text = ""
402         for name, mode, hexsha in self.iteritems():
403             self._text += "%04o %s\0%s" % (mode, name, hex_to_sha(hexsha))
404
405
406 class Commit(ShaFile):
407     """A git commit object"""
408
409     _type = COMMIT_ID
410     _num_type = 1
411
412     def __init__(self):
413         self._parents = []
414
415     @classmethod
416     def from_file(cls, filename):
417         commit = ShaFile.from_file(filename)
418         if commit._type != cls._type:
419             raise NotCommitError(filename)
420         return commit
421
422     def _parse_text(self):
423         text = self._text
424         count = 0
425         assert text.startswith(TREE_ID), "Invalid commit object, " \
426              "must start with %s" % TREE_ID
427         count += len(TREE_ID)
428         assert text[count] == ' ', "Invalid commit object, " \
429              "%s must be followed by space not %s" % (TREE_ID, text[count])
430         count += 1
431         self._tree = text[count:count+40]
432         count = count + 40
433         assert text[count] == "\n", "Invalid commit object, " \
434              "tree sha must be followed by newline"
435         count += 1
436         self._parents = []
437         while text[count:].startswith(PARENT_ID):
438             count += len(PARENT_ID)
439             assert text[count] == ' ', "Invalid commit object, " \
440                  "%s must be followed by space not %s" % (PARENT_ID, text[count])
441             count += 1
442             self._parents.append(text[count:count+40])
443             count += 40
444             assert text[count] == "\n", "Invalid commit object, " \
445                  "parent sha must be followed by newline"
446             count += 1
447         self._author = None
448         if text[count:].startswith(AUTHOR_ID):
449             count += len(AUTHOR_ID)
450             assert text[count] == ' ', "Invalid commit object, " \
451                  "%s must be followed by space not %s" % (AUTHOR_ID, text[count])
452             count += 1
453             self._author = ''
454             while text[count] != '>':
455                 assert text[count] != '\n', "Malformed author information"
456                 self._author += text[count]
457                 count += 1
458             self._author += text[count]
459             count += 1
460             assert text[count] == ' ', "Invalid commit object, " \
461                  "author information must be followed by space not %s" % text[count]
462             count += 1
463             self._author_time = int(text[count:count+10])
464             while text[count] != ' ':
465                 assert text[count] != '\n', "Malformed author information"
466                 count += 1
467             self._author_timezone = int(text[count:count+6])
468             count += 1
469             while text[count] != '\n':
470                 count += 1
471             count += 1
472         self._committer = None
473         if text[count:].startswith(COMMITTER_ID):
474             count += len(COMMITTER_ID)
475             assert text[count] == ' ', "Invalid commit object, " \
476                  "%s must be followed by space not %s" % (COMMITTER_ID, text[count])
477             count += 1
478             self._committer = ''
479             while text[count] != '>':
480                 assert text[count] != '\n', "Malformed committer information"
481                 self._committer += text[count]
482                 count += 1
483             self._committer += text[count]
484             count += 1
485             assert text[count] == ' ', "Invalid commit object, " \
486                  "commiter information must be followed by space not %s" % text[count]
487             count += 1
488             self._commit_time = int(text[count:count+10])
489             while text[count] != ' ':
490                 assert text[count] != '\n', "Malformed committer information"
491                 count += 1
492             self._commit_timezone = int(text[count:count+6])
493             count += 1
494             while text[count] != '\n':
495                 count += 1
496             count += 1
497         assert text[count] == '\n', "There must be a new line after the headers"
498         count += 1
499         # XXX: There can be an encoding field.
500         self._message = text[count:]
501
502     def serialize(self):
503         self._text = ""
504         self._text += "%s %s\n" % (TREE_ID, self._tree)
505         for p in self._parents:
506             self._text += "%s %s\n" % (PARENT_ID, p)
507         self._text += "%s %s %s %+05d\n" % (AUTHOR_ID, self._author, str(self._author_time), self._author_timezone)
508         self._text += "%s %s %s %+05d\n" % (COMMITTER_ID, self._committer, str(self._commit_time), self._commit_timezone)
509         self._text += "\n" # There must be a new line after the headers
510         self._text += self._message
511
512     @property
513     def tree(self):
514         """Returns the tree that is the state of this commit"""
515         return self._tree
516
517     @property
518     def parents(self):
519         """Return a list of parents of this commit."""
520         return self._parents
521
522     @property
523     def author(self):
524         """Returns the name of the author of the commit"""
525         return self._author
526
527     @property
528     def committer(self):
529         """Returns the name of the committer of the commit"""
530         return self._committer
531
532     @property
533     def message(self):
534         """Returns the commit message"""
535         return self._message
536
537     @property
538     def commit_time(self):
539         """Returns the timestamp of the commit.
540         
541         Returns it as the number of seconds since the epoch.
542         """
543         return self._commit_time
544
545     @property
546     def commit_timezone(self):
547         """Returns the zone the commit time is in
548         """
549         return self._commit_timezone
550
551     @property
552     def author_time(self):
553         """Returns the timestamp the commit was written.
554         
555         Returns it as the number of seconds since the epoch.
556         """
557         return self._author_time
558
559     @property
560     def author_timezone(self):
561         """Returns the zone the author time is in
562         """
563         return self._author_timezone
564
565
566 type_map = {
567     BLOB_ID : Blob,
568     TREE_ID : Tree,
569     COMMIT_ID : Commit,
570     TAG_ID: Tag,
571 }
572
573 num_type_map = {
574     0: None,
575     1: Commit,
576     2: Tree,
577     3: Blob,
578     4: Tag,
579     # 5 Is reserved for further expansion
580 }
581
582 try:
583     # Try to import C versions
584     from dulwich._objects import hex_to_sha, sha_to_hex, parse_tree
585 except ImportError:
586     pass
587