Correctly avoid parsing ShaFiles with fixed SHAs when calling sha().
authorDave Borowitz <dborowitz@google.com>
Mon, 8 Mar 2010 18:47:05 +0000 (10:47 -0800)
committerDave Borowitz <dborowitz@google.com>
Fri, 16 Apr 2010 18:56:53 +0000 (11:56 -0700)
This required some reworking of the _needs_* and _sha ivars. Improved
check() to force computing the SHA and verifying that it matches the
previously-set value. Added a test for this check.

Change-Id: I6782693d7d7708bc7c28f357c27419d51409b884

dulwich/_objects.c
dulwich/objects.py
dulwich/tests/data/blobs/11/11111111111111111111111111111111111111 [new file with mode: 0644]
dulwich/tests/test_objects.py

index fef82e78f791f35f5cbc9feecab8eea7584d7715..d9699ec6cb4a72e52873929f679048e32d176819 100644 (file)
@@ -60,7 +60,7 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
                mode = strtol(text, &text, 8);
 
                if (*text != ' ') {
-                       PyErr_SetString(PyExc_RuntimeError, "Expected space");
+                       PyErr_SetString(PyExc_ValueError, "Expected space");
                        Py_DECREF(ret);
                        return NULL;
                }
@@ -76,7 +76,7 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
                }
 
                if (text + namelen + 20 >= end) {
-                       PyErr_SetString(PyExc_RuntimeError, "SHA truncated");
+                       PyErr_SetString(PyExc_ValueError, "SHA truncated");
                        Py_DECREF(ret);
                        Py_DECREF(name);
                        return NULL;
index 562958f07608240a00d9aac5a3b7def1e5739c57..237c6fa52b88cec19922ea20b14e9d739eb477be 100644 (file)
@@ -31,6 +31,7 @@ import stat
 import zlib
 
 from dulwich.errors import (
+    ChecksumMismatch,
     NotBlobError,
     NotCommitError,
     NotTagError,
@@ -208,7 +209,7 @@ class ShaFile(object):
     def as_raw_chunks(self):
         if self._needs_parsing:
             self._ensure_parsed()
-        else:
+        elif self._needs_serialization:
             self._chunked_text = self._serialize()
         return self._chunked_text
 
@@ -239,8 +240,9 @@ class ShaFile(object):
 
     def set_raw_chunks(self, chunks):
         self._chunked_text = chunks
+        self._deserialize(chunks)
         self._sha = None
-        self._needs_parsing = True
+        self._needs_parsing = False
         self._needs_serialization = False
 
     @staticmethod
@@ -370,15 +372,22 @@ class ShaFile(object):
         """Check this object for internal consistency.
 
         :raise ObjectFormatException: if the object is malformed in some way
+        :raise ChecksumMismatch: if the object was created with a SHA that does
+            not match its contents
         """
         # TODO: if we find that error-checking during object parsing is a
         # performance bottleneck, those checks should be moved to the class's
         # check() method during optimization so we can still check the object
         # when necessary.
+        old_sha = self.id
         try:
             self._deserialize(self.as_raw_chunks())
+            self._sha = None
+            new_sha = self.id
         except Exception, e:
             raise ObjectFormatException(e)
+        if old_sha != new_sha:
+            raise ChecksumMismatch(new_sha, old_sha)
 
     def _header(self):
         return "%s %lu\0" % (self.type_name, self.raw_length())
@@ -399,8 +408,13 @@ class ShaFile(object):
 
     def sha(self):
         """The SHA1 object that is the name of this object."""
-        if self._needs_serialization or self._sha is None:
-            self._sha = self._make_sha()
+        if self._sha is None:
+            # this is a local because as_raw_chunks() overwrites self._sha
+            new_sha = make_sha()
+            new_sha.update(self._header())
+            for chunk in self.as_raw_chunks():
+                new_sha.update(chunk)
+            self._sha = new_sha
         return self._sha
 
     @property
@@ -483,7 +497,7 @@ class Blob(ShaFile):
 
         :raise ObjectFormatException: if the object is malformed in some way
         """
-        pass  # it's impossible for raw data to be malformed
+        super(Blob, self).check()
 
 
 def _parse_tag_or_commit(text):
@@ -580,7 +594,10 @@ class Tag(ShaFile):
             if field == _OBJECT_HEADER:
                 self._object_sha = value
             elif field == _TYPE_HEADER:
-                self._object_class = object_class(value)
+                obj_class = object_class(value)
+                if not obj_class:
+                    raise ObjectFormatException("Not a known type: %s" % value)
+                self._object_class = obj_class
             elif field == _TAG_HEADER:
                 self._name = value
             elif field == _TAGGER_HEADER:
@@ -593,14 +610,17 @@ class Tag(ShaFile):
                     self._tag_timezone_neg_utc = False
                 else:
                     self._tagger = value[0:sep+1]
-                    (timetext, timezonetext) = value[sep+2:].rsplit(" ", 1)
-                    self._tag_time = int(timetext)
-                    self._tag_timezone, self._tag_timezone_neg_utc = \
-                            parse_timezone(timezonetext)
+                    try:
+                        (timetext, timezonetext) = value[sep+2:].rsplit(" ", 1)
+                        self._tag_time = int(timetext)
+                        self._tag_timezone, self._tag_timezone_neg_utc = \
+                                parse_timezone(timezonetext)
+                    except ValueError, e:
+                        raise ObjectFormatException(e)
             elif field is None:
                 self._message = value
             else:
-                raise AssertionError("Unknown field %s" % field)
+                raise ObjectFormatError("Unknown field %s" % field)
 
     def _get_object(self):
         """Get the object pointed to by this tag.
@@ -746,7 +766,10 @@ class Tree(ShaFile):
 
     def _deserialize(self, chunks):
         """Grab the entries in the tree"""
-        parsed_entries = parse_tree("".join(chunks))
+        try:
+            parsed_entries = parse_tree("".join(chunks))
+        except ValueError, e:
+            raise ObjectFormatException(e)
         # TODO: list comprehension is for efficiency in the common (small) case;
         # if memory efficiency in the large case is a concern, use a genexp.
         self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
diff --git a/dulwich/tests/data/blobs/11/11111111111111111111111111111111111111 b/dulwich/tests/data/blobs/11/11111111111111111111111111111111111111
new file mode 100644 (file)
index 0000000..1942d23
Binary files /dev/null and b/dulwich/tests/data/blobs/11/11111111111111111111111111111111111111 differ
index 0d275b1f65abc9af7ca38abdcd7430e7123bbe9e..e5680725bbb07faae54099ba34437a927c66ef5d 100644 (file)
@@ -29,6 +29,7 @@ import stat
 import unittest
 
 from dulwich.errors import (
+    ChecksumMismatch,
     ObjectFormatException,
     )
 from dulwich.objects import (
@@ -210,22 +211,27 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
 
+    def test_check_id(self):
+        wrong_sha = '1' * 40
+        b = self.get_blob(wrong_sha)
+        self.assertEqual(wrong_sha, b.id)
+        self.assertRaises(ChecksumMismatch, b.check)
+        self.assertEqual('742b386350576589175e374a5706505cbd17680c', b.id)
+
 
 class ShaFileCheckTests(unittest.TestCase):
 
     def assertCheckFails(self, cls, data):
         obj = cls()
-        obj.set_raw_string(data)
-        self.assertRaises(ObjectFormatException, obj.check)
+        def do_check():
+            obj.set_raw_string(data)
+            obj.check()
+        self.assertRaises(ObjectFormatException, do_check)
 
     def assertCheckSucceeds(self, cls, data):
         obj = cls()
         obj.set_raw_string(data)
-        try:
-            obj.check()
-        except ObjectFormatException, e:
-            raise
-            self.fail(e)
+        self.assertEqual(None, obj.check())
 
 
 class CommitSerializationTests(unittest.TestCase):