Distinguish between ShaFile.from_file and ShaFile.from_path.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_objects.py
index 8d8b007eaf7b80d0af4fbe31a188a2009be8cb31..5f32e61785ef3f48cacf3c21b794bbb3adc03353 100644 (file)
@@ -29,6 +29,7 @@ import stat
 import unittest
 
 from dulwich.errors import (
+    ChecksumMismatch,
     ObjectFormatException,
     )
 from dulwich.objects import (
@@ -38,6 +39,8 @@ from dulwich.objects import (
     Tag,
     format_timezone,
     hex_to_sha,
+    sha_to_hex,
+    hex_to_filename,
     check_hexsha,
     check_identity,
     parse_timezone,
@@ -87,13 +90,22 @@ except ImportError:
                 return
 
 
+class TestHexToSha(unittest.TestCase):
+
+    def test_simple(self):
+        self.assertEquals("\xab\xcd" * 10, hex_to_sha("abcd" * 10))
+
+    def test_reverse(self):
+        self.assertEquals("abcd" * 10, sha_to_hex("\xab\xcd" * 10))
+
+
 class BlobReadTests(unittest.TestCase):
     """Test decompression of blobs"""
-  
-    def get_sha_file(self, obj, base, sha):
-        return obj.from_file(os.path.join(os.path.dirname(__file__),
-                                          'data', base, sha))
-  
+
+    def get_sha_file(self, cls, base, sha):
+        dir = os.path.join(os.path.dirname(__file__), 'data', base)
+        return cls.from_path(hex_to_filename(dir, sha))
+
     def get_blob(self, sha):
         """Return the blob named sha from the test data dir"""
         return self.get_sha_file(Blob, 'blobs', sha)
@@ -209,20 +221,27 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
 
+    def test_check_id(self):
+        wrong_sha = '1' * 40
+        b = self.get_blob(wrong_sha)
+        self.assertEqual(wrong_sha, b.id)
+        self.assertRaises(ChecksumMismatch, b.check)
+        self.assertEqual('742b386350576589175e374a5706505cbd17680c', b.id)
+
 
 class ShaFileCheckTests(unittest.TestCase):
 
-    def assertCheckFails(self, obj, data):
-        obj.set_raw_string(data)
-        self.assertRaises(ObjectFormatException, obj.check)
+    def assertCheckFails(self, cls, data):
+        obj = cls()
+        def do_check():
+            obj.set_raw_string(data)
+            obj.check()
+        self.assertRaises(ObjectFormatException, do_check)
 
-    def assertCheckSucceeds(self, obj, data):
+    def assertCheckSucceeds(self, cls, data):
+        obj = cls()
         obj.set_raw_string(data)
-        try:
-            obj.check()
-        except ObjectFormatException, e:
-            raise
-            self.fail(e)
+        self.assertEqual(None, obj.check())
 
 
 class CommitSerializationTests(unittest.TestCase):
@@ -343,22 +362,22 @@ class CommitParseTests(ShaFileCheckTests):
         self.assertEquals('UTF-8', c.encoding)
 
     def test_check(self):
-        self.assertCheckSucceeds(Commit(), self.make_commit_text())
-        self.assertCheckSucceeds(Commit(), self.make_commit_text(parents=None))
-        self.assertCheckSucceeds(Commit(),
+        self.assertCheckSucceeds(Commit, self.make_commit_text())
+        self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+        self.assertCheckSucceeds(Commit,
                                  self.make_commit_text(encoding='UTF-8'))
 
-        self.assertCheckFails(Commit(), self.make_commit_text(tree='xxx'))
-        self.assertCheckFails(Commit(), self.make_commit_text(
+        self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+        self.assertCheckFails(Commit, self.make_commit_text(
           parents=[a_sha, 'xxx']))
         bad_committer = "some guy without an email address 1174773719 +0000"
-        self.assertCheckFails(Commit(),
+        self.assertCheckFails(Commit,
                               self.make_commit_text(committer=bad_committer))
-        self.assertCheckFails(Commit(),
+        self.assertCheckFails(Commit,
                               self.make_commit_text(author=bad_committer))
-        self.assertCheckFails(Commit(), self.make_commit_text(author=None))
-        self.assertCheckFails(Commit(), self.make_commit_text(committer=None))
-        self.assertCheckFails(Commit(), self.make_commit_text(
+        self.assertCheckFails(Commit, self.make_commit_text(author=None))
+        self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+        self.assertCheckFails(Commit, self.make_commit_text(
           author=None, committer=None))
 
     def test_check_duplicates(self):
@@ -369,9 +388,9 @@ class CommitParseTests(ShaFileCheckTests):
             text = '\n'.join(lines)
             if lines[i].startswith('parent'):
                 # duplicate parents are ok for now
-                self.assertCheckSucceeds(Commit(), text)
+                self.assertCheckSucceeds(Commit, text)
             else:
-                self.assertCheckFails(Commit(), text)
+                self.assertCheckFails(Commit, text)
 
     def test_check_order(self):
         lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
@@ -382,9 +401,9 @@ class CommitParseTests(ShaFileCheckTests):
             perm = list(perm)
             text = '\n'.join(perm + rest)
             if perm == headers:
-                self.assertCheckSucceeds(Commit(), text)
+                self.assertCheckSucceeds(Commit, text)
             else:
-                self.assertCheckFails(Commit(), text)
+                self.assertCheckFails(Commit, text)
 
 
 class TreeTests(ShaFileCheckTests):
@@ -404,8 +423,9 @@ class TreeTests(ShaFileCheckTests):
         self.assertEquals(["a.c", "a", "a/c"], [p[0] for p in x.iteritems()])
 
     def _do_test_parse_tree(self, parse_tree):
-        o = Tree.from_file(os.path.join(os.path.dirname(__file__), 'data',
-                                        'trees', tree_sha))
+        dir = os.path.join(os.path.dirname(__file__), 'data', 'trees')
+        o = Tree.from_path(hex_to_filename(dir, tree_sha))
+        o._parse_file()
         self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
                           list(parse_tree(o.as_raw_string())))
 
@@ -418,7 +438,7 @@ class TreeTests(ShaFileCheckTests):
         self._do_test_parse_tree(parse_tree)
 
     def test_check(self):
-        t = Tree()
+        t = Tree
         sha = hex_to_sha(a_sha)
 
         # filenames
@@ -530,26 +550,26 @@ class TagParseTests(ShaFileCheckTests):
         self.assertEquals("v2.6.22-rc7", x.name)
 
     def test_check(self):
-        self.assertCheckSucceeds(Tag(), self.make_tag_text())
-        self.assertCheckFails(Tag(), self.make_tag_text(object_sha=None))
-        self.assertCheckFails(Tag(), self.make_tag_text(object_type_name=None))
-        self.assertCheckFails(Tag(), self.make_tag_text(name=None))
-        self.assertCheckFails(Tag(), self.make_tag_text(name=''))
-        self.assertCheckFails(Tag(), self.make_tag_text(
+        self.assertCheckSucceeds(Tag, self.make_tag_text())
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+        self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=''))
+        self.assertCheckFails(Tag, self.make_tag_text(
           object_type_name="foobar"))
-        self.assertCheckFails(Tag(), self.make_tag_text(
+        self.assertCheckFails(Tag, self.make_tag_text(
           tagger="some guy without an email address 1183319674 -0700"))
-        self.assertCheckFails(Tag(), self.make_tag_text(
+        self.assertCheckFails(Tag, self.make_tag_text(
           tagger=("Linus Torvalds <torvalds@woody.linux-foundation.org> "
                   "Sun 7 Jul 2007 12:54:34 +0700")))
-        self.assertCheckFails(Tag(), self.make_tag_text(object_sha="xxx"))
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
 
     def test_check_duplicates(self):
         # duplicate each of the header fields
         for i in xrange(4):
             lines = self.make_tag_lines()
             lines.insert(i, lines[i])
-            self.assertCheckFails(Tag(), '\n'.join(lines))
+            self.assertCheckFails(Tag, '\n'.join(lines))
 
     def test_check_order(self):
         lines = self.make_tag_lines()
@@ -560,9 +580,9 @@ class TagParseTests(ShaFileCheckTests):
             perm = list(perm)
             text = '\n'.join(perm + rest)
             if perm == headers:
-                self.assertCheckSucceeds(Tag(), text)
+                self.assertCheckSucceeds(Tag, text)
             else:
-                self.assertCheckFails(Tag(), text)
+                self.assertCheckFails(Tag, text)
 
 
 class CheckTests(unittest.TestCase):