Set _num_type so commit writer stuff works
[jelmer/dulwich-libgit2.git] / dulwich / objects.py
1 # objects.py -- Acces to base git objects
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3 # The header parsing code is based on that from git itself, which is
4 # Copyright (C) 2005 Linus Torvalds
5 # and licensed under v2 of the GPL.
6
7 # This program is free software; you can redistribute it and/or
8 # modify it under the terms of the GNU General Public License
9 # as published by the Free Software Foundation; version 2
10 # of the License.
11
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16
17 # You should have received a copy of the GNU General Public License
18 # along with this program; if not, write to the Free Software
19 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
20 # MA  02110-1301, USA.
21
22 import mmap
23 import os
24 import sha
25 import zlib
26
27 from errors import (NotCommitError,
28                     NotTreeError,
29                     NotBlobError,
30                     )
31
32 BLOB_ID = "blob"
33 TAG_ID = "tag"
34 TREE_ID = "tree"
35 COMMIT_ID = "commit"
36 PARENT_ID = "parent"
37 AUTHOR_ID = "author"
38 COMMITTER_ID = "committer"
39
40 def _decompress(string):
41     dcomp = zlib.decompressobj()
42     dcomped = dcomp.decompress(string)
43     dcomped += dcomp.flush()
44     return dcomped
45
46 def sha_to_hex(sha):
47   """Takes a string and returns the hex of the sha within"""
48   hexsha = ''
49   for c in sha:
50     hexsha += "%02x" % ord(c)
51   assert len(hexsha) == 40, "Incorrect length of sha1 string: %d" % \
52          len(hexsha)
53   return hexsha
54
55 def hex_to_sha(hex):
56   """Takes a hex sha and returns a binary sha"""
57   sha = ''
58   for i in range(0,19):
59     sha += chr(int(hex[i:i+2], 16))
60   assert len(sha) == 20, "Incorrent length of sha1"
61   return sha
62
63 class ShaFile(object):
64   """A git SHA file."""
65
66   @classmethod
67   def _parse_legacy_object(cls, map):
68     """Parse a legacy object, creating it and setting object._text"""
69     text = _decompress(map)
70     object = None
71     for posstype in type_map.keys():
72       if text.startswith(posstype):
73         object = type_map[posstype]()
74         text = text[len(posstype):]
75         break
76     assert object is not None, "%s is not a known object type" % text[:9]
77     assert text[0] == ' ', "%s is not a space" % text[0]
78     text = text[1:]
79     size = 0
80     i = 0
81     while text[0] >= '0' and text[0] <= '9':
82       if i > 0 and size == 0:
83         assert False, "Size is not in canonical format"
84       size = (size * 10) + int(text[0])
85       text = text[1:]
86       i += 1
87     object._size = size
88     assert text[0] == "\0", "Size not followed by null"
89     text = text[1:]
90     object._text = text
91     return object
92
93   def as_raw_string(self):
94     return self._num_type, self._text
95
96   @classmethod
97   def _parse_object(cls, map):
98     """Parse a new style object , creating it and setting object._text"""
99     used = 0
100     byte = ord(map[used])
101     used += 1
102     num_type = (byte >> 4) & 7
103     try:
104       object = num_type_map[num_type]()
105     except KeyError:
106       assert False, "Not a known type: %d" % num_type
107     while((byte & 0x80) != 0):
108       byte = ord(map[used])
109       used += 1
110     raw = map[used:]
111     object._text = _decompress(raw)
112     return object
113
114   @classmethod
115   def _parse_file(cls, map):
116     word = (ord(map[0]) << 8) + ord(map[1])
117     if ord(map[0]) == 0x78 and (word % 31) == 0:
118       return cls._parse_legacy_object(map)
119     else:
120       return cls._parse_object(map)
121
122   def __init__(self):
123     """Don't call this directly"""
124
125   def _parse_text(self):
126     """For subclasses to do initialisation time parsing"""
127
128   @classmethod
129   def from_file(cls, filename):
130     """Get the contents of a SHA file on disk"""
131     size = os.path.getsize(filename)
132     f = open(filename, 'rb')
133     try:
134       map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
135       shafile = cls._parse_file(map)
136       shafile._parse_text()
137       return shafile
138     finally:
139       f.close()
140
141   @classmethod
142   def from_raw_string(cls, type, string):
143     """Creates an object of the indicated type from the raw string given.
144
145     Type is the numeric type of an object. String is the raw uncompressed
146     contents.
147     """
148     real_class = num_type_map[type]
149     obj = real_class()
150     obj._num_type = type
151     obj._text = string
152     obj._parse_text()
153     return obj
154
155   def _header(self):
156     return "%s %lu\0" % (self._type, len(self._text))
157
158   def crc32(self):
159     return zlib.crc32(self._text)
160
161   def sha(self):
162     """The SHA1 object that is the name of this object."""
163     ressha = sha.new()
164     ressha.update(self._header())
165     ressha.update(self._text)
166     return ressha
167
168   @property
169   def id(self):
170       return self.sha().hexdigest()
171
172   def __repr__(self):
173     return "<%s %s>" % (self.__class__.__name__, self.id)
174
175   def __eq__(self, other):
176     """Return true id the sha of the two objects match.
177
178     The __le__ etc methods aren't overriden as they make no sense,
179     certainly at this level.
180     """
181     return self.sha().digest() == other.sha().digest()
182
183
184 class Blob(ShaFile):
185   """A Git Blob object."""
186
187   _type = BLOB_ID
188   _num_type = 3
189
190   @property
191   def data(self):
192     """The text contained within the blob object."""
193     return self._text
194
195   @classmethod
196   def from_file(cls, filename):
197     blob = ShaFile.from_file(filename)
198     if blob._type != cls._type:
199       raise NotBlobError(filename)
200     return blob
201
202   @classmethod
203   def from_string(cls, string):
204     """Create a blob from a string."""
205     shafile = cls()
206     shafile._text = string
207     return shafile
208
209
210 class Tag(ShaFile):
211   """A Git Tag object."""
212
213   _type = TAG_ID
214
215   @classmethod
216   def from_file(cls, filename):
217     blob = ShaFile.from_file(filename)
218     if blob._type != cls._type:
219       raise NotBlobError(filename)
220     return blob
221
222   @classmethod
223   def from_string(cls, string):
224     """Create a blob from a string."""
225     shafile = cls()
226     shafile._text = string
227     return shafile
228
229
230 class Tree(ShaFile):
231   """A Git tree object"""
232
233   _type = TREE_ID
234   _num_type = 2
235
236   def __init__(self):
237     self._entries = []
238
239   @classmethod
240   def from_file(cls, filename):
241     tree = ShaFile.from_file(filename)
242     if tree._type != cls._type:
243       raise NotTreeError(filename)
244     return tree
245
246   def add(self, mode, name, hexsha):
247     self._entries.append((mode, name, hexsha))
248
249   def entries(self):
250     """Return a list of tuples describing the tree entries"""
251     return self._entries
252
253   def _parse_text(self):
254     """Grab the entries in the tree"""
255     count = 0
256     while count < len(self._text):
257       mode = 0
258       chr = self._text[count]
259       while chr != ' ':
260         assert chr >= '0' and chr <= '7', "%s is not a valid mode char" % chr
261         mode = (mode << 3) + (ord(chr) - ord('0'))
262         count += 1
263         chr = self._text[count]
264       count += 1
265       chr = self._text[count]
266       name = ''
267       while chr != '\0':
268         name += chr
269         count += 1
270         chr = self._text[count]
271       count += 1
272       chr = self._text[count]
273       sha = self._text[count:count+20]
274       hexsha = sha_to_hex(sha)
275       self.add(mode, name, hexsha)
276       count = count + 20
277
278   def serialize(self):
279     self._text = ""
280     for mode, name, hexsha in self._entries:
281         self._text += "%04o %s\0%s" % (mode, name, hex_to_sha(hexsha))
282
283
284 class Commit(ShaFile):
285   """A git commit object"""
286
287   _type = COMMIT_ID
288   _num_type = 1
289
290   @classmethod
291   def from_file(cls, filename):
292     commit = ShaFile.from_file(filename)
293     if commit._type != cls._type:
294       raise NotCommitError(filename)
295     return commit
296
297   def _parse_text(self):
298     text = self._text
299     count = 0
300     assert text.startswith(TREE_ID), "Invalid commit object, " \
301          "must start with %s" % TREE_ID
302     count += len(TREE_ID)
303     assert text[count] == ' ', "Invalid commit object, " \
304          "%s must be followed by space not %s" % (TREE_ID, text[count])
305     count += 1
306     self._tree = text[count:count+40]
307     count = count + 40
308     assert text[count] == "\n", "Invalid commit object, " \
309          "tree sha must be followed by newline"
310     count += 1
311     self._parents = []
312     while text[count:].startswith(PARENT_ID):
313       count += len(PARENT_ID)
314       assert text[count] == ' ', "Invalid commit object, " \
315            "%s must be followed by space not %s" % (PARENT_ID, text[count])
316       count += 1
317       self._parents.append(text[count:count+40])
318       count += 40
319       assert text[count] == "\n", "Invalid commit object, " \
320            "parent sha must be followed by newline"
321       count += 1
322     self._author = None
323     if text[count:].startswith(AUTHOR_ID):
324       count += len(AUTHOR_ID)
325       assert text[count] == ' ', "Invalid commit object, " \
326            "%s must be followed by space not %s" % (AUTHOR_ID, text[count])
327       count += 1
328       self._author = ''
329       while text[count] != '>':
330         assert text[count] != '\n', "Malformed author information"
331         self._author += text[count]
332         count += 1
333       self._author += text[count]
334       count += 1
335       while text[count] != '\n':
336         count += 1
337       count += 1
338     self._committer = None
339     if text[count:].startswith(COMMITTER_ID):
340       count += len(COMMITTER_ID)
341       assert text[count] == ' ', "Invalid commit object, " \
342            "%s must be followed by space not %s" % (COMMITTER_ID, text[count])
343       count += 1
344       self._committer = ''
345       while text[count] != '>':
346         assert text[count] != '\n', "Malformed committer information"
347         self._committer += text[count]
348         count += 1
349       self._committer += text[count]
350       count += 1
351       assert text[count] == ' ', "Invalid commit object, " \
352            "commiter information must be followed by space not %s" % text[count]
353       count += 1
354       self._commit_time = int(text[count:count+10])
355       while text[count] != '\n':
356         count += 1
357       count += 1
358     assert text[count] == '\n', "There must be a new line after the headers"
359     count += 1
360     # XXX: There can be an encoding field.
361     self._message = text[count:]
362
363   def serialize(self):
364     self._text = ""
365     self._text += "%s %s\n" % (TREE_ID, self._tree)
366     for p in self._parents:
367       self._text += "%s %s\n" % (PARENT_ID, p)
368     self._text += "%s %s %s +0000\n" % (AUTHOR_ID, self._author, str(self._commit_time))
369     self._text += "%s %s %s +0000\n" % (COMMITTER_ID, self._committer, str(self._commit_time))
370     self._text += message
371
372   @property
373   def tree(self):
374     """Returns the tree that is the state of this commit"""
375     return self._tree
376
377   @property
378   def parents(self):
379     """Return a list of parents of this commit."""
380     return self._parents
381
382   @property
383   def author(self):
384     """Returns the name of the author of the commit"""
385     return self._author
386
387   @property
388   def committer(self):
389     """Returns the name of the committer of the commit"""
390     return self._committer
391
392   @property
393   def message(self):
394     """Returns the commit message"""
395     return self._message
396
397   @property
398   def commit_time(self):
399     """Returns the timestamp of the commit.
400     
401     Returns it as the number of seconds since the epoch.
402     """
403     return self._commit_time
404
405
406 type_map = {
407   BLOB_ID : Blob,
408   TREE_ID : Tree,
409   COMMIT_ID : Commit,
410   TAG_ID: Tag,
411 }
412
413 num_type_map = {
414   0: None,
415   1: Commit,
416   2: Tree,
417   3: Blob,
418   4: Tag,
419   # 5 Is reserved for further expansion
420 }
421