Clean up file headers.
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_repository.py
index 4aa0e1e4b63c8757224491cd8127a2ebe9fc434d..ee7d861754d77af0b37a68911ed523046a678df2 100644 (file)
@@ -1,23 +1,22 @@
 # test_repository.py -- tests for repository.py
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License or (at your option) any later version of 
+# of the License or (at your option) any later version of
 # the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
-
 """Tests for the repository."""
 
 from cStringIO import StringIO
@@ -68,11 +67,11 @@ class RepositoryTests(unittest.TestCase):
     def tearDown(self):
         if self._repo is not None:
             tear_down_repo(self._repo)
-  
+
     def test_simple_props(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.controldir(), r.path)
-  
+
     def test_ref(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
@@ -83,7 +82,7 @@ class RepositoryTests(unittest.TestCase):
         r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
         self.assertEquals('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                           r["refs/tags/foo"].id)
-  
+
     def test_get_refs(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual({
@@ -92,7 +91,7 @@ class RepositoryTests(unittest.TestCase):
             'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
             'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
             }, r.get_refs())
-  
+
     def test_head(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
@@ -216,14 +215,14 @@ class RepositoryTests(unittest.TestCase):
             self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
         finally:
             warnings.resetwarnings()
-    
+
     def test_linear_history(self):
         r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
-  
+
     def test_merge_history(self):
         r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
@@ -233,12 +232,12 @@ class RepositoryTests(unittest.TestCase):
                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                 '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
-  
+
     def test_revision_history_missing_commit(self):
         r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
-  
+
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         r = self._repo = open_repo('ooo_merge.git')
@@ -248,7 +247,7 @@ class RepositoryTests(unittest.TestCase):
                                 'f507291b64138b875c28e03469025b1ea20bc614',
                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
-  
+
     def test_get_tags_empty(self):
         r = self._repo = open_repo('ooo_merge.git')
         self.assertEqual({}, r.refs.as_dict('refs/tags'))
@@ -306,7 +305,15 @@ class RepositoryTests(unittest.TestCase):
             shutil.rmtree(r1_dir)
             shutil.rmtree(r2_dir)
 
-    def _build_initial_repo(self):
+
+class BuildRepoTests(unittest.TestCase):
+    """Tests that build on-disk repos from scratch.
+
+    Repos live in a temp dir and are torn down after each test. They start with
+    a single commit in master having single file named 'a'.
+    """
+
+    def setUp(self):
         repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
         os.makedirs(repo_dir)
         r = self._repo = Repo.init(repo_dir)
@@ -326,20 +333,21 @@ class RepositoryTests(unittest.TestCase):
                                  commit_timestamp=12345, commit_timezone=0,
                                  author_timestamp=12345, author_timezone=0)
         self.assertEqual([], r[commit_sha].parents)
-        return commit_sha
+        self._root_commit = commit_sha
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
 
     def test_build_repo(self):
-        commit_sha = self._build_initial_repo()
         r = self._repo
         self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
-        self.assertEqual(commit_sha, r.refs['refs/heads/master'])
+        self.assertEqual(self._root_commit, r.refs['refs/heads/master'])
         expected_blob = objects.Blob.from_string('file contents')
         self.assertEqual(expected_blob.data, r[expected_blob.id].data)
-        actual_commit = r[commit_sha]
+        actual_commit = r[self._root_commit]
         self.assertEqual('msg', actual_commit.message)
 
     def test_commit_modified(self):
-        parent_sha = self._build_initial_repo()
         r = self._repo
         f = open(os.path.join(r.path, 'a'), 'wb')
         try:
@@ -352,12 +360,11 @@ class RepositoryTests(unittest.TestCase):
                                  author='Test Author <test@nodomain.com>',
                                  commit_timestamp=12395, commit_timezone=0,
                                  author_timestamp=12395, author_timezone=0)
-        self.assertEqual([parent_sha], r[commit_sha].parents)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
         _, blob_id = tree_lookup_path(r.get_object, r[commit_sha].tree, 'a')
         self.assertEqual('new contents', r[blob_id].data)
 
     def test_commit_deleted(self):
-        parent_sha = self._build_initial_repo()
         r = self._repo
         os.remove(os.path.join(r.path, 'a'))
         r.stage(['a'])
@@ -366,40 +373,40 @@ class RepositoryTests(unittest.TestCase):
                                  author='Test Author <test@nodomain.com>',
                                  commit_timestamp=12395, commit_timezone=0,
                                  author_timestamp=12395, author_timezone=0)
-        self.assertEqual([parent_sha], r[commit_sha].parents)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
         self.assertEqual([], list(r.open_index()))
         tree = r[r[commit_sha].tree]
         self.assertEqual([], tree.iteritems())
 
     def test_commit_fail_ref(self):
-        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
-        os.makedirs(repo_dir)
-        r = self._repo = Repo.init(repo_dir)
+        r = self._repo
 
         def set_if_equals(name, old_ref, new_ref):
-            self.fail('Unexpected call to set_if_equals')
+            return False
         r.refs.set_if_equals = set_if_equals
 
         def add_if_new(name, new_ref):
-            return False
+            self.fail('Unexpected call to add_if_new')
         r.refs.add_if_new = add_if_new
 
+        old_shas = set(r.object_store)
         self.assertRaises(errors.CommitError, r.do_commit, 'failed commit',
                           committer='Test Committer <test@nodomain.com>',
                           author='Test Author <test@nodomain.com>',
                           commit_timestamp=12345, commit_timezone=0,
                           author_timestamp=12345, author_timezone=0)
-        shas = list(r.object_store)
-        self.assertEqual(2, len(shas))
-        for sha in shas:
-            obj = r[sha]
-            if isinstance(obj, objects.Commit):
-                commit = obj
-            elif isinstance(obj, objects.Tree):
-                tree = obj
-            else:
-                self.fail('Unexpected object found: %s' % sha)
-        self.assertEqual(tree.id, commit.tree)
+        new_shas = set(r.object_store) - old_shas
+        self.assertEqual(1, len(new_shas))
+        # Check that the new commit (now garbage) was added.
+        new_commit = r[new_shas.pop()]
+        self.assertEqual(r[self._root_commit].tree, new_commit.tree)
+        self.assertEqual('failed commit', new_commit.message)
+
+    def test_stage_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        r.stage(['a'])  # double-stage a deleted path
 
 
 class CheckRefFormatTests(unittest.TestCase):
@@ -727,6 +734,19 @@ class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
         self.assertFalse(os.path.exists(
             os.path.join(self._refs.path, 'HEAD.lock')))
 
+    def test_remove_packed_without_peeled(self):
+        refs_file = os.path.join(self._repo.path, 'packed-refs')
+        f = open(refs_file)
+        refs_data = f.read()
+        f.close()
+        f = open(refs_file, 'w')
+        f.write('\n'.join(l for l in refs_data.split('\n')
+                          if not l or l[0] not in '#^'))
+        f.close()
+        self._repo = Repo(self._repo.path)
+        refs = self._repo.refs
+        self.assertTrue(refs.remove_if_equals(
+          'refs/heads/packed', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
 
     def test_remove_if_equals_packed(self):
         # test removing ref that is only packed
@@ -739,7 +759,7 @@ class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
 
     def test_read_ref(self):
         self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec', 
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
             self._refs.read_ref("refs/heads/packed"))
         self.assertEqual(None,
             self._refs.read_ref("nonexistant"))