Check tag and commit objects for duplicate and out-of-order headers.
authorDave Borowitz <dborowitz@google.com>
Sat, 27 Feb 2010 19:29:44 +0000 (11:29 -0800)
committerJelmer Vernooij <jelmer@samba.org>
Mon, 12 Apr 2010 15:07:31 +0000 (08:07 -0700)
This requires factoring out the commit/tag parsing code so it can be
used directly from both the _parse_text and check methods. The parse
methods yield tuples, which can either be used to set members or
check for ordering.

Change-Id: I5ffe47100273912eaa283d03332286287b109a13

dulwich/objects.py
dulwich/tests/test_objects.py

index ca6b271d30392260104b173770ebc4fb5da49983..c3d296b774a7d82583840c8dcb5f8c692ca8b7f2 100644 (file)
@@ -393,6 +393,29 @@ class Blob(ShaFile):
         pass  # it's impossible for raw data to be malformed
 
 
+def _parse_tag_or_commit(text):
+    """Parse tag or commit text.
+
+    :param text: the raw text of the tag or commit object.
+    :yield: tuples of (field, value), one per header line, in the order read
+        from the text, possibly including duplicates. Includes a field named
+        None for the freeform tag/commit text.
+    """
+    f = StringIO(text)
+    for l in f:
+        l = l.rstrip("\n")
+        if l == "":
+            # Empty line indicates end of headers
+            break
+        yield l.split(" ", 1)
+    yield (None, f.read())
+    f.close()
+
+
+def parse_tag(text):
+    return _parse_tag_or_commit(text)
+
+
 class Tag(ShaFile):
     """A Git Tag object."""
 
@@ -425,7 +448,6 @@ class Tag(ShaFile):
         :raise ObjectFormatException: if the object is malformed in some way
         """
         super(Tag, self).check()
-        # TODO(dborowitz): check header order
         self._check_has_member("_object_sha", "missing object sha")
         self._check_has_member("_object_class", "missing object type")
         self._check_has_member("_name", "missing tag name")
@@ -438,6 +460,18 @@ class Tag(ShaFile):
         if getattr(self, "_tagger", None):
             check_identity(self._tagger, "invalid tagger")
 
+        last = None
+        for field, _ in parse_tag("".join(self._chunked_text)):
+            if field == _OBJECT_HEADER and last is not None:
+                raise ObjectFormatException("unexpected object")
+            elif field == _TYPE_HEADER and last != _OBJECT_HEADER:
+                raise ObjectFormatException("unexpected type")
+            elif field == _TAG_HEADER and last != _TYPE_HEADER:
+                raise ObjectFormatException("unexpected tag name")
+            elif field == _TAGGER_HEADER and last != _TAG_HEADER:
+                raise ObjectFormatException("unexpected tagger")
+            last = field
+
     def _serialize(self):
         chunks = []
         chunks.append("%s %s\n" % (_OBJECT_HEADER, self._object_sha))
@@ -458,12 +492,7 @@ class Tag(ShaFile):
     def _deserialize(self, chunks):
         """Grab the metadata attached to the tag"""
         self._tagger = None
-        f = StringIO("".join(chunks))
-        for l in f:
-            l = l.rstrip("\n")
-            if l == "":
-                break # empty line indicates end of headers
-            (field, value) = l.split(" ", 1)
+        for field, value in parse_tag("".join(chunks)):
             if field == _OBJECT_HEADER:
                 self._object_sha = value
             elif field == _TYPE_HEADER:
@@ -484,9 +513,10 @@ class Tag(ShaFile):
                     self._tag_time = int(timetext)
                     self._tag_timezone, self._tag_timezone_neg_utc = \
                             parse_timezone(timezonetext)
+            elif field is None:
+                self._message = value
             else:
                 raise AssertionError("Unknown field %s" % field)
-        self._message = f.read()
 
     def _get_object(self):
         """Get the object pointed to by this tag.
@@ -702,6 +732,10 @@ def format_timezone(offset, negative_utc=False):
     return '%c%02d%02d' % (sign, offset / 3600, (offset / 60) % 60)
 
 
+def parse_commit(text):
+    return _parse_tag_or_commit(text)
+
+
 class Commit(ShaFile):
     """A git commit object"""
 
@@ -729,13 +763,7 @@ class Commit(ShaFile):
         self._parents = []
         self._extra = []
         self._author = None
-        f = StringIO("".join(chunks))
-        for l in f:
-            l = l.rstrip("\n")
-            if l == "":
-                # Empty line indicates end of headers
-                break
-            (field, value) = l.split(" ", 1)
+        for field, value in parse_commit("".join(self._chunked_text)):
             if field == _TREE_HEADER:
                 self._tree = value
             elif field == _PARENT_HEADER:
@@ -752,9 +780,10 @@ class Commit(ShaFile):
                     parse_timezone(timezonetext)
             elif field == _ENCODING_HEADER:
                 self._encoding = value
+            elif field is None:
+                self._message = value
             else:
                 self._extra.append((field, value))
-        self._message = f.read()
 
     def check(self):
         """Check this object for internal consistency.
@@ -762,8 +791,6 @@ class Commit(ShaFile):
         :raise ObjectFormatException: if the object is malformed in some way
         """
         super(Commit, self).check()
-        # TODO(dborowitz): check header order
-        # TODO(dborowitz): check for duplicate headers
         self._check_has_member("_tree", "missing tree")
         self._check_has_member("_author", "missing author")
         self._check_has_member("_committer", "missing committer")
@@ -776,6 +803,24 @@ class Commit(ShaFile):
         check_identity(self._author, "invalid author")
         check_identity(self._committer, "invalid committer")
 
+        last = None
+        for field, _ in parse_commit("".join(self._chunked_text)):
+            if field == _TREE_HEADER and last is not None:
+                raise ObjectFormatException("unexpected tree")
+            elif field == _PARENT_HEADER and last not in (_PARENT_HEADER,
+                                                          _TREE_HEADER):
+                raise ObjectFormatException("unexpected parent")
+            elif field == _AUTHOR_HEADER and last not in (_TREE_HEADER,
+                                                          _PARENT_HEADER):
+                raise ObjectFormatException("unexpected author")
+            elif field == _COMMITTER_HEADER and last != _AUTHOR_HEADER:
+                raise ObjectFormatException("unexpected committer")
+            elif field == _ENCODING_HEADER and last != _COMMITTER_HEADER:
+                raise ObjectFormatException("unexpected encoding")
+            last = field
+
+        # TODO: optionally check for duplicate parents
+
     def _serialize(self):
         chunks = []
         chunks.append("%s %s\n" % (_TREE_HEADER, self._tree))
index a209a7d5d2ea82320b7fae30b139f5ebb30afae6..8d8b007eaf7b80d0af4fbe31a188a2009be8cb31 100644 (file)
@@ -54,6 +54,39 @@ c_sha = '954a536f7819d40e6f637f849ee187dd10066349'
 tree_sha = '70c190eb48fa8bbb50ddc692a17b44cb781af7f6'
 tag_sha = '71033db03a03c6a36721efcf1968dd8f8e0cf023'
 
+
+try:
+    from itertools import permutations
+except ImportError:
+    # Implementation of permutations from Python 2.6 documentation:
+    # http://docs.python.org/2.6/library/itertools.html#itertools.permutations
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+    def permutations(iterable, r=None):
+        # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
+        # permutations(range(3)) --> 012 021 102 120 201 210
+        pool = tuple(iterable)
+        n = len(pool)
+        r = n if r is None else r
+        if r > n:
+            return
+        indices = range(n)
+        cycles = range(n, n-r, -1)
+        yield tuple(pool[i] for i in indices[:r])
+        while n:
+            for i in reversed(range(r)):
+                cycles[i] -= 1
+                if cycles[i] == 0:
+                    indices[i:] = indices[i+1:] + indices[i:i+1]
+                    cycles[i] = n - i
+                else:
+                    j = cycles[i]
+                    indices[i], indices[-j] = indices[-j], indices[i]
+                    yield tuple(pool[i] for i in indices[:r])
+                    break
+            else:
+                return
+
+
 class BlobReadTests(unittest.TestCase):
     """Test decompression of blobs"""
   
@@ -250,15 +283,15 @@ default_committer = 'James Westby <jw+debian@jameswestby.net> 1174773719 +0000'
 
 class CommitParseTests(ShaFileCheckTests):
 
-    def make_commit_text(self,
-                         tree='d80c186a03f423a81b39df39dc87fd269736ca86',
-                         parents=['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
-                                  '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
-                         author=default_committer,
-                         committer=default_committer,
-                         encoding=None,
-                         message='Merge ../b\n',
-                         extra=None):
+    def make_commit_lines(self,
+                          tree='d80c186a03f423a81b39df39dc87fd269736ca86',
+                          parents=['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                                   '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          author=default_committer,
+                          committer=default_committer,
+                          encoding=None,
+                          message='Merge ../b\n',
+                          extra=None):
         lines = []
         if tree is not None:
             lines.append('tree %s' % tree)
@@ -276,7 +309,10 @@ class CommitParseTests(ShaFileCheckTests):
         lines.append('')
         if message is not None:
             lines.append(message)
-        return '\n'.join(lines)
+        return lines
+
+    def make_commit_text(self, **kwargs):
+        return '\n'.join(self.make_commit_lines(**kwargs))
 
     def test_simple(self):
         c = Commit.from_string(self.make_commit_text())
@@ -325,6 +361,31 @@ class CommitParseTests(ShaFileCheckTests):
         self.assertCheckFails(Commit(), self.make_commit_text(
           author=None, committer=None))
 
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(5):
+            lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+            lines.insert(i, lines[i])
+            text = '\n'.join(lines)
+            if lines[i].startswith('parent'):
+                # duplicate parents are ok for now
+                self.assertCheckSucceeds(Commit(), text)
+            else:
+                self.assertCheckFails(Commit(), text)
+
+    def test_check_order(self):
+        lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+        headers = lines[:5]
+        rest = lines[5:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Commit(), text)
+            else:
+                self.assertCheckFails(Commit(), text)
+
 
 class TreeTests(ShaFileCheckTests):
 
@@ -425,12 +486,12 @@ OK2XeQOiEeXtT76rV4t2WR4=
 
 
 class TagParseTests(ShaFileCheckTests):
-    def make_tag_text(self,
-                      object_sha="a38d6181ff27824c79fc7df825164a212eff6a3f",
-                      object_type_name="commit",
-                      name="v2.6.22-rc7",
-                      tagger=default_tagger,
-                      message=default_message):
+    def make_tag_lines(self,
+                       object_sha="a38d6181ff27824c79fc7df825164a212eff6a3f",
+                       object_type_name="commit",
+                       name="v2.6.22-rc7",
+                       tagger=default_tagger,
+                       message=default_message):
         lines = []
         if object_sha is not None:
             lines.append("object %s" % object_sha)
@@ -443,7 +504,10 @@ class TagParseTests(ShaFileCheckTests):
         lines.append("")
         if message is not None:
             lines.append(message)
-        return "\n".join(lines)
+        return lines
+
+    def make_tag_text(self, **kwargs):
+        return "\n".join(self.make_tag_lines(**kwargs))
 
     def test_parse(self):
         x = Tag()
@@ -480,6 +544,26 @@ class TagParseTests(ShaFileCheckTests):
                   "Sun 7 Jul 2007 12:54:34 +0700")))
         self.assertCheckFails(Tag(), self.make_tag_text(object_sha="xxx"))
 
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(4):
+            lines = self.make_tag_lines()
+            lines.insert(i, lines[i])
+            self.assertCheckFails(Tag(), '\n'.join(lines))
+
+    def test_check_order(self):
+        lines = self.make_tag_lines()
+        headers = lines[:4]
+        rest = lines[4:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Tag(), text)
+            else:
+                self.assertCheckFails(Tag(), text)
+
 
 class CheckTests(unittest.TestCase):