Merge improvements and extra tests, mainly to deal better with creating non-bare...
authorJelmer Vernooij <jelmer@samba.org>
Mon, 17 May 2010 21:50:35 +0000 (23:50 +0200)
committerJelmer Vernooij <jelmer@samba.org>
Mon, 17 May 2010 21:50:35 +0000 (23:50 +0200)
dulwich/_objects.c
dulwich/errors.py
dulwich/index.py
dulwich/objects.py
dulwich/repo.py
dulwich/tests/test_index.py
dulwich/tests/test_repository.py

index 986098bdb0b26d29193b7952b010e08048071135..5e3e845bec79c18072dbd8586d8622035523c459 100644 (file)
@@ -158,7 +158,7 @@ static PyObject *py_sorted_tree_items(PyObject *self, PyObject *entries)
 
        i = 0;
        while (PyDict_Next(entries, &pos, &key, &value)) {
-               PyObject *py_mode, *py_sha;
+               PyObject *py_mode, *py_int_mode, *py_sha;
                
                if (PyTuple_Size(value) != 2) {
                        PyErr_SetString(PyExc_ValueError, "Tuple has invalid size");
@@ -167,20 +167,22 @@ static PyObject *py_sorted_tree_items(PyObject *self, PyObject *entries)
                }
 
                py_mode = PyTuple_GET_ITEM(value, 0);
+               py_int_mode = PyNumber_Int(py_mode);
+               if (!py_int_mode) {
+                       PyErr_SetString(PyExc_TypeError, "Mode is not an integral type");
+                       free(qsort_entries);
+                       return NULL;
+               }
+
                py_sha = PyTuple_GET_ITEM(value, 1);
-               qsort_entries[i].tuple = Py_BuildValue("(OOO)", key, py_mode, py_sha);
                if (!PyString_CheckExact(key)) {
                        PyErr_SetString(PyExc_TypeError, "Name is not a string");
                        free(qsort_entries);
                        return NULL;
                }
                qsort_entries[i].name = PyString_AS_STRING(key);
-               if (!PyInt_CheckExact(py_mode)) {
-                       PyErr_SetString(PyExc_TypeError, "Mode is not an int");
-                       free(qsort_entries);
-                       return NULL;
-               }
                qsort_entries[i].mode = PyInt_AS_LONG(py_mode);
+               qsort_entries[i].tuple = PyTuple_Pack(3, key, py_mode, py_sha);
                i++;
        }
 
index 0514d4d0e8a9e0089c2021adb2ab1bc0c8c2b2e9..80f54b380ce2c537a7d4da6037dcae886e1b5264 100644 (file)
@@ -151,3 +151,7 @@ class ObjectFormatException(FileFormatException):
 
 class NoIndexPresent(Exception):
     """No index is present."""
+
+
+class CommitError(Exception):
+    """An error occurred while performing a commit."""
index b2c2619e86d41358f128568cdc982d01909a4fc4..b6649a8bf9863a85c40df215ac7176af5df5ef94 100644 (file)
@@ -204,6 +204,8 @@ class Index(object):
 
     def read(self):
         """Read current contents of index from disk."""
+        if not os.path.exists(self._filename):
+            return
         f = GitFile(self._filename, 'rb')
         try:
             f = SHA1Reader(f)
@@ -254,6 +256,10 @@ class Index(object):
         # Remove the old entry if any
         self._byname[name] = x
 
+    def __delitem__(self, name):
+        assert isinstance(name, str)
+        del self._byname[name]
+
     def iteritems(self):
         return self._byname.iteritems()
 
index 969e65341e315752f33ce01defdb15e774947d22..d78e6d7e00b0c95491eccc2cb3be42852fb82ed2 100644 (file)
@@ -747,10 +747,16 @@ class Tree(ShaFile):
         return self._entries[name]
 
     def __setitem__(self, name, value):
-        assert isinstance(value, tuple)
-        assert len(value) == 2
+        """Set a tree entry by name.
+
+        :param name: The name of the entry, as a string.
+        :param value: A tuple of (mode, hexsha), where mode is the mode of the
+            entry as an integral type and hexsha is the hex SHA of the entry as
+            a string.
+        """
+        mode, hexsha = value
         self._ensure_parsed()
-        self._entries[name] = value
+        self._entries[name] = (mode, hexsha)
         self._needs_serialization = True
 
     def __delitem__(self, name):
@@ -767,9 +773,13 @@ class Tree(ShaFile):
         return iter(self._entries)
 
     def add(self, mode, name, hexsha):
-        assert type(mode) == int
-        assert type(name) == str
-        assert type(hexsha) == str
+        """Add an entry to the tree.
+
+        :param mode: The mode of the entry as an integral type. Not all possible
+            modes are supported by git; see check() for details.
+        :param name: The name of the entry, as a string.
+        :param hexsha: The hex SHA of the entry as a string.
+        """
         self._ensure_parsed()
         self._entries[name] = mode, hexsha
         self._needs_serialization = True
index 3148b43712c7cd7db0034190aa15f6568bdbc7cc..6cf5df1940a793f46115c29e3560df43efaf28e5 100644 (file)
@@ -34,6 +34,7 @@ from dulwich.errors import (
     NotTreeError,
     NotTagError,
     PackedRefsException,
+    CommitError,
     )
 from dulwich.file import (
     ensure_dir_exists,
@@ -126,7 +127,7 @@ class RefsContainer(object):
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         """
-        self[name] = SYMREF + other + '\n'
+        raise NotImplementedError(self.set_symbolic_ref)
 
     def get_packed_refs(self):
         """Get contents of the packed-refs file.
@@ -152,10 +153,14 @@ class RefsContainer(object):
         for name, value in other.iteritems():
             self["%s/%s" % (base, name)] = value
 
+    def allkeys(self):
+        """All refs present in this container."""
+        raise NotImplementedError(self.allkeys)
+
     def keys(self, base=None):
         """Refs present in this container.
 
-        :param base: An optional base to return refs under
+        :param base: An optional base to return refs under.
         :return: An unsorted set of valid refs in this container, including
             packed refs.
         """
@@ -165,10 +170,17 @@ class RefsContainer(object):
             return self.allkeys()
 
     def subkeys(self, base):
+        """Refs present in this container under a base.
+
+        :param base: The base to return refs under.
+        :return: A set of valid refs in this container under the base; the base
+            prefix is stripped from the ref names returned.
+        """
         keys = set()
+        base_len = len(base) + 1
         for refname in self.allkeys():
             if refname.startswith(base):
-                keys.add(refname)
+                keys.add(refname[base_len:])
         return keys
 
     def as_dict(self, base=None):
@@ -258,8 +270,74 @@ class RefsContainer(object):
             raise KeyError(name)
         return sha
 
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references if applicable for the
+        subclass, and can be used to perform an atomic compare-and-swap
+        operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        raise NotImplementedError(self.set_if_equals)
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        raise NotImplementedError(self.add_if_new)
+
+    def __setitem__(self, name, ref):
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references if applicable for the
+        subclass.
+
+        :note: This method unconditionally overwrites the contents of a
+            reference. To update atomically only if the reference has not
+            changed, use set_if_equals().
+        :param name: The refname to set.
+        :param ref: The new sha the refname will refer to.
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references, even if applicable for
+        the subclass. It can be used to perform an atomic compare-and-delete
+        operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        raise NotImplementedError(self.remove_if_equals)
+
+    def __delitem__(self, name):
+        """Remove a refname.
+
+        This method does not follow symbolic references, even if applicable for
+        the subclass.
+
+        :note: This method unconditionally deletes the contents of a reference.
+            To delete atomically only if the reference has not changed, use
+            remove_if_equals().
+
+        :param name: The refname to delete.
+        """
+        self.remove_if_equals(name, None)
+
 
 class DictRefsContainer(RefsContainer):
+    """RefsContainer backed by a simple dict.
+
+    This container does not support symbolic or packed references and is not
+    threadsafe.
+    """
 
     def __init__(self, refs):
         self._refs = refs
@@ -268,10 +346,32 @@ class DictRefsContainer(RefsContainer):
         return self._refs.keys()
 
     def read_loose_ref(self, name):
-        return self._refs[name]
+        return self._refs.get(name, None)
 
-    def __setitem__(self, name, value):
-        self._refs[name] = value
+    def get_packed_refs(self):
+        return {}
+
+    def set_symbolic_ref(self, name, other):
+        self._refs[name] = SYMREF + other
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        if old_ref is not None and self._refs.get(name, None) != old_ref:
+            return False
+        realname, _ = self._follow(name)
+        self._refs[realname] = new_ref
+        return True
+
+    def add_if_new(self, name, ref):
+        if name in self._refs:
+            return False
+        self._refs[name] = ref
+        return True
+
+    def remove_if_equals(self, name, old_ref):
+        if old_ref is not None and self._refs.get(name, None) != old_ref:
+            return False
+        del self._refs[name]
+        return True
 
 
 class DiskRefsContainer(RefsContainer):
@@ -426,6 +526,25 @@ class DiskRefsContainer(RefsContainer):
         finally:
             f.abort()
 
+    def set_symbolic_ref(self, name, other):
+        """Make a ref point at another ref.
+
+        :param name: Name of the ref to set
+        :param other: Name of the ref to point at
+        """
+        self._check_refname(name)
+        self._check_refname(other)
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'wb')
+            try:
+                f.write(SYMREF + other + '\n')
+            except (IOError, OSError):
+                f.abort()
+                raise
+        finally:
+            f.close()
+
     def set_if_equals(self, name, old_ref, new_ref):
         """Set a refname to new_ref only if it currently equals old_ref.
 
@@ -468,9 +587,23 @@ class DiskRefsContainer(RefsContainer):
         return True
 
     def add_if_new(self, name, ref):
-        """Add a new reference only if it does not already exist."""
-        self._check_refname(name)
-        filename = self.refpath(name)
+        """Add a new reference only if it does not already exist.
+
+        This method follows symrefs, and only ensures that the last ref in the
+        chain does not exist.
+
+        :param name: The refname to set.
+        :param ref: The new sha the refname will refer to.
+        :return: True if the add was successful, False otherwise.
+        """
+        try:
+            realname, contents = self._follow(name)
+            if contents is not None:
+                return False
+        except KeyError:
+            realname = name
+        self._check_refname(realname)
+        filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, 'wb')
         try:
@@ -486,17 +619,6 @@ class DiskRefsContainer(RefsContainer):
             f.close()
         return True
 
-    def __setitem__(self, name, ref):
-        """Set a reference name to point to the given SHA1.
-
-        This method follows all symbolic references.
-
-        :note: This method unconditionally overwrites the contents of a reference
-            on disk. To update atomically only if the reference has not changed
-            on disk, use set_if_equals().
-        """
-        self.set_if_equals(name, None, ref)
-
     def remove_if_equals(self, name, old_ref):
         """Remove a refname only if it currently equals old_ref.
 
@@ -531,16 +653,6 @@ class DiskRefsContainer(RefsContainer):
             f.abort()
         return True
 
-    def __delitem__(self, name):
-        """Remove a refname.
-
-        This method does not follow symbolic references.
-        :note: This method unconditionally deletes the contents of a reference
-            on disk. To delete atomically only if the reference has not changed
-            on disk, use set_if_equals().
-        """
-        self.remove_if_equals(name, None)
-
 
 def _split_ref_line(line):
     """Split a single ref line into a tuple of SHA1 and name."""
@@ -917,8 +1029,20 @@ class BaseRepo(object):
             author_timezone = commit_timezone
         c.author_timezone = author_timezone
         c.message = message
-        self.object_store.add_object(c)
-        self.refs["HEAD"] = c.id
+        try:
+            old_head = self.refs["HEAD"]
+            c.parents = [old_head]
+            self.object_store.add_object(c)
+            ok = self.refs.set_if_equals("HEAD", old_head, c.id)
+        except KeyError:
+            c.parents = []
+            self.object_store.add_object(c)
+            ok = self.refs.add_if_new("HEAD", c.id)
+        if not ok:
+            # Fail if the atomic compare-and-swap failed, leaving the commit and
+            # all its objects as garbage.
+            raise CommitError("HEAD changed during commit")
+
         return c.id
 
 
@@ -984,7 +1108,9 @@ class Repo(BaseRepo):
 
     def has_index(self):
         """Check if an index is present."""
-        return os.path.exists(self.index_path())
+        # Bare repos must never have index files; non-bare repos may have a
+        # missing index file, which is treated as empty.
+        return not self.bare
 
     def stage(self, paths):
         """Stage a set of paths.
@@ -994,14 +1120,18 @@ class Repo(BaseRepo):
         from dulwich.index import cleanup_mode
         index = self.open_index()
         for path in paths:
+            full_path = os.path.join(self.path, path)
             blob = Blob()
             try:
-                st = os.stat(path)
+                st = os.stat(full_path)
             except OSError:
                 # File no longer exists
-                del index[path]
+                try:
+                    del index[path]
+                except KeyError:
+                    pass  # Doesn't exist in the index either
             else:
-                f = open(path, 'rb')
+                f = open(full_path, 'rb')
                 try:
                     blob.data = f.read()
                 finally:
index f1d8e1c3bfac0e5508976f7fefc2afbc5740700a..f23079837b1a5ce7f04961cea98b5e6b2797e9e7 100644 (file)
@@ -68,6 +68,11 @@ class SimpleIndexTestCase(IndexTestCase):
                            'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0),
                           self.get_simple_index("index")["bla"])
 
+    def test_empty(self):
+        i = self.get_simple_index("notanindex")
+        self.assertEquals(0, len(i))
+        self.assertFalse(os.path.exists(i._filename))
+
 
 class SimpleIndexWriterTestCase(IndexTestCase):
 
index e1ea7ede772bb7cccd09cbebe512cd2d054a3f01..9918e024ec612097c21133673a04bb04dbea2caa 100644 (file)
@@ -28,9 +28,13 @@ import unittest
 import warnings
 
 from dulwich import errors
+from dulwich.object_store import (
+    tree_lookup_path,
+    )
 from dulwich import objects
 from dulwich.repo import (
     check_ref_format,
+    DictRefsContainer,
     Repo,
     read_packed_refs,
     read_packed_refs_with_peeled,
@@ -303,6 +307,109 @@ class RepositoryTests(unittest.TestCase):
             shutil.rmtree(r2_dir)
 
 
+class BuildRepoTests(unittest.TestCase):
+    """Tests that build on-disk repos from scratch.
+
+    Repos live in a temp dir and are torn down after each test. They start with
+    a single commit in master having single file named 'a'.
+    """
+
+    def setUp(self):
+        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
+        os.makedirs(repo_dir)
+        r = self._repo = Repo.init(repo_dir)
+        self.assertFalse(r.bare)
+        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
+        self.assertRaises(KeyError, lambda: r.refs['refs/heads/master'])
+
+        f = open(os.path.join(r.path, 'a'), 'wb')
+        try:
+            f.write('file contents')
+        finally:
+            f.close()
+        r.stage(['a'])
+        commit_sha = r.do_commit('msg',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12345, commit_timezone=0,
+                                 author_timestamp=12345, author_timezone=0)
+        self.assertEqual([], r[commit_sha].parents)
+        self._root_commit = commit_sha
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_build_repo(self):
+        r = self._repo
+        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
+        self.assertEqual(self._root_commit, r.refs['refs/heads/master'])
+        expected_blob = objects.Blob.from_string('file contents')
+        self.assertEqual(expected_blob.data, r[expected_blob.id].data)
+        actual_commit = r[self._root_commit]
+        self.assertEqual('msg', actual_commit.message)
+
+    def test_commit_modified(self):
+        r = self._repo
+        f = open(os.path.join(r.path, 'a'), 'wb')
+        try:
+            f.write('new contents')
+        finally:
+            f.close()
+        r.stage(['a'])
+        commit_sha = r.do_commit('modified a',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12395, commit_timezone=0,
+                                 author_timestamp=12395, author_timezone=0)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
+        _, blob_id = tree_lookup_path(r.get_object, r[commit_sha].tree, 'a')
+        self.assertEqual('new contents', r[blob_id].data)
+
+    def test_commit_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        commit_sha = r.do_commit('deleted a',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12395, commit_timezone=0,
+                                 author_timestamp=12395, author_timezone=0)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
+        self.assertEqual([], list(r.open_index()))
+        tree = r[r[commit_sha].tree]
+        self.assertEqual([], tree.iteritems())
+
+    def test_commit_fail_ref(self):
+        r = self._repo
+
+        def set_if_equals(name, old_ref, new_ref):
+            return False
+        r.refs.set_if_equals = set_if_equals
+
+        def add_if_new(name, new_ref):
+            self.fail('Unexpected call to add_if_new')
+        r.refs.add_if_new = add_if_new
+
+        old_shas = set(r.object_store)
+        self.assertRaises(errors.CommitError, r.do_commit, 'failed commit',
+                          committer='Test Committer <test@nodomain.com>',
+                          author='Test Author <test@nodomain.com>',
+                          commit_timestamp=12345, commit_timezone=0,
+                          author_timestamp=12345, author_timezone=0)
+        new_shas = set(r.object_store) - old_shas
+        self.assertEqual(1, len(new_shas))
+        # Check that the new commit (now garbage) was added.
+        new_commit = r[new_shas.pop()]
+        self.assertEqual(r[self._root_commit].tree, new_commit.tree)
+        self.assertEqual('failed commit', new_commit.message)
+
+    def test_stage_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        r.stage(['a'])  # double-stage a deleted path
+
+
 class CheckRefFormatTests(unittest.TestCase):
     """Tests for the check_ref_format function.
 
@@ -383,7 +490,121 @@ class PackedRefsFileTests(unittest.TestCase):
         self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
 
 
-class RefsContainerTests(unittest.TestCase):
+# Dict of refs that we expect all RefsContainerTests subclasses to define.
+_TEST_REFS = {
+  'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+  'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
+  }
+
+
+class RefsContainerTests(object):
+
+    def test_keys(self):
+        actual_keys = set(self._refs.keys())
+        self.assertEqual(set(self._refs.allkeys()), actual_keys)
+        # ignore the symref loop if it exists
+        actual_keys.discard('refs/heads/loop')
+        self.assertEqual(set(_TEST_REFS.iterkeys()), actual_keys)
+
+        actual_keys = self._refs.keys('refs/heads')
+        actual_keys.discard('loop')
+        self.assertEqual(['master', 'packed'], sorted(actual_keys))
+        self.assertEqual(['refs-0.1', 'refs-0.2'],
+                         sorted(self._refs.keys('refs/tags')))
+
+    def test_as_dict(self):
+        # refs/heads/loop does not show up even if it exists
+        self.assertEqual(_TEST_REFS, self._refs.as_dict())
+
+    def test_setitem(self):
+        self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/some/ref'])
+
+    def test_set_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        self.assertTrue(self._refs.set_if_equals(
+          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+        self.assertEqual(nines, self._refs['HEAD'])
+
+        self.assertTrue(self._refs.set_if_equals('refs/heads/master', None,
+                                                 nines))
+        self.assertEqual(nines, self._refs['refs/heads/master'])
+
+    def test_add_if_new(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+
+        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
+        self.assertEqual(nines, self._refs['refs/some/ref'])
+
+    def test_set_symbolic_ref(self):
+        self._refs.set_symbolic_ref('refs/heads/symbolic', 'refs/heads/master')
+        self.assertEqual('ref: refs/heads/master',
+                         self._refs.read_loose_ref('refs/heads/symbolic'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/symbolic'])
+
+    def test_set_symbolic_ref_overwrite(self):
+        nines = '9' * 40
+        self.assertFalse('refs/heads/symbolic' in self._refs)
+        self._refs['refs/heads/symbolic'] = nines
+        self.assertEqual(nines, self._refs.read_loose_ref('refs/heads/symbolic'))
+        self._refs.set_symbolic_ref('refs/heads/symbolic', 'refs/heads/master')
+        self.assertEqual('ref: refs/heads/master',
+                         self._refs.read_loose_ref('refs/heads/symbolic'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/symbolic'])
+
+    def test_check_refname(self):
+        try:
+            self._refs._check_refname('HEAD')
+        except KeyError:
+            self.fail()
+
+        try:
+            self._refs._check_refname('refs/heads/foo')
+        except KeyError:
+            self.fail()
+
+        self.assertRaises(KeyError, self._refs._check_refname, 'refs')
+        self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
+
+    def test_contains(self):
+        self.assertTrue('refs/heads/master' in self._refs)
+        self.assertFalse('refs/heads/bar' in self._refs)
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+    def test_remove_if_equals(self):
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+        self.assertTrue(self._refs.remove_if_equals(
+          'refs/tags/refs-0.2', '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
+        self.assertFalse('refs/tags/refs-0.2' in self._refs)
+
+
+class DictRefsContainerTests(RefsContainerTests, unittest.TestCase):
+
+    def setUp(self):
+        self._refs = DictRefsContainer(dict(_TEST_REFS))
+
+
+class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
 
     def setUp(self):
         self._repo = open_repo('refs.git')
@@ -412,34 +633,8 @@ class RefsContainerTests(unittest.TestCase):
         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                          self._refs.get_peeled('refs/tags/refs-0.1'))
 
-    def test_keys(self):
-        self.assertEqual([
-          'HEAD',
-          'refs/heads/loop',
-          'refs/heads/master',
-          'refs/heads/packed',
-          'refs/tags/refs-0.1',
-          'refs/tags/refs-0.2',
-          ], sorted(list(self._refs.keys())))
-        self.assertEqual(['loop', 'master', 'packed'],
-                         sorted(self._refs.keys('refs/heads')))
-        self.assertEqual(['refs-0.1', 'refs-0.2'],
-                         sorted(self._refs.keys('refs/tags')))
-
-    def test_as_dict(self):
-        # refs/heads/loop does not show up
-        self.assertEqual({
-          'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-          'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-          'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-          'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
-          'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
-          }, self._refs.as_dict())
-
     def test_setitem(self):
-        self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                         self._refs['refs/some/ref'])
+        RefsContainerTests.test_setitem(self)
         f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                           f.read()[:40])
@@ -461,50 +656,41 @@ class RefsContainerTests(unittest.TestCase):
         f.close()
 
     def test_set_if_equals(self):
-        nines = '9' * 40
-        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                         self._refs['HEAD'])
-
-        self.assertTrue(self._refs.set_if_equals(
-          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
-        self.assertEqual(nines, self._refs['HEAD'])
+        RefsContainerTests.test_set_if_equals(self)
 
         # ensure symref was followed
-        self.assertEqual(nines, self._refs['refs/heads/master'])
+        self.assertEqual('9' * 40, self._refs['refs/heads/master'])
 
+        # ensure lockfile was deleted
         self.assertFalse(os.path.exists(
           os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
         self.assertFalse(os.path.exists(
           os.path.join(self._refs.path, 'HEAD.lock')))
 
-    def test_add_if_new(self):
-        nines = '9' * 40
-        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                         self._refs['refs/heads/master'])
-
-        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
-        self.assertEqual(nines, self._refs['refs/some/ref'])
-
+    def test_add_if_new_packed(self):
         # don't overwrite packed ref
-        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', '9' * 40))
         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                          self._refs['refs/tags/refs-0.1'])
 
-    def test_check_refname(self):
-        try:
-            self._refs._check_refname('HEAD')
-        except KeyError:
-            self.fail()
-
-        try:
-            self._refs._check_refname('refs/heads/foo')
-        except KeyError:
-            self.fail()
+    def test_add_if_new_symbolic(self):
+        # Use an empty repo instead of the default.
+        tear_down_repo(self._repo)
+        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
+        os.makedirs(repo_dir)
+        self._repo = Repo.init(repo_dir)
+        refs = self._repo.refs
 
-        self.assertRaises(KeyError, self._refs._check_refname, 'refs')
-        self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
+        nines = '9' * 40
+        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
+        self.assertFalse('refs/heads/master' in refs)
+        self.assertTrue(refs.add_if_new('HEAD', nines))
+        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
+        self.assertEqual(nines, refs['HEAD'])
+        self.assertEqual(nines, refs['refs/heads/master'])
+        self.assertFalse(refs.add_if_new('HEAD', '1' * 40))
+        self.assertEqual(nines, refs['HEAD'])
+        self.assertEqual(nines, refs['refs/heads/master'])
 
     def test_follow(self):
         self.assertEquals(
@@ -516,15 +702,8 @@ class RefsContainerTests(unittest.TestCase):
         self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
         self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
 
-    def test_contains(self):
-        self.assertTrue('refs/heads/master' in self._refs)
-        self.assertFalse('refs/heads/bar' in self._refs)
-
     def test_delitem(self):
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                          self._refs['refs/heads/master'])
-        del self._refs['refs/heads/master']
-        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        RefsContainerTests.test_delitem(self)
         ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
         self.assertFalse(os.path.exists(ref_file))
         self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
@@ -538,12 +717,7 @@ class RefsContainerTests(unittest.TestCase):
                          self._refs['refs/heads/master'])
         self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
 
-    def test_remove_if_equals(self):
-        nines = '9' * 40
-        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                         self._refs['HEAD'])
-
+    def test_remove_if_equals_symref(self):
         # HEAD is a symref, so shouldn't equal its dereferenced value
         self.assertFalse(self._refs.remove_if_equals(
           'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
@@ -561,6 +735,8 @@ class RefsContainerTests(unittest.TestCase):
         self.assertFalse(os.path.exists(
             os.path.join(self._refs.path, 'HEAD.lock')))
 
+
+    def test_remove_if_equals_packed(self):
         # test removing ref that is only packed
         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                          self._refs['refs/tags/refs-0.1'])
@@ -575,4 +751,3 @@ class RefsContainerTests(unittest.TestCase):
             self._refs.read_ref("refs/heads/packed"))
         self.assertEqual(None,
             self._refs.read_ref("nonexistant"))
-