Merge improvements and extra tests, mainly to deal better with creating non-bare...
[jelmer/dulwich-libgit2.git] / dulwich / tests / test_repository.py
1 # test_repository.py -- tests for repository.py
2 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
3
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; version 2
7 # of the License or (at your option) any later version of 
8 # the License.
9
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
18 # MA  02110-1301, USA.
19
20
21 """Tests for the repository."""
22
23 from cStringIO import StringIO
24 import os
25 import shutil
26 import tempfile
27 import unittest
28 import warnings
29
30 from dulwich import errors
31 from dulwich.object_store import (
32     tree_lookup_path,
33     )
34 from dulwich import objects
35 from dulwich.repo import (
36     check_ref_format,
37     DictRefsContainer,
38     Repo,
39     read_packed_refs,
40     read_packed_refs_with_peeled,
41     write_packed_refs,
42     _split_ref_line,
43     )
44 from dulwich.tests.utils import (
45     open_repo,
46     tear_down_repo,
47     )
48
# SHA-1 the tests expect to be absent from the fixture repositories; used to
# exercise the KeyError / MissingCommitError paths.
missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
50
51
class CreateRepositoryTests(unittest.TestCase):
    """Tests for creating brand-new repositories."""

    def test_create(self):
        """A bare repo uses the target directory itself as its control dir."""
        target_dir = tempfile.mkdtemp()
        try:
            new_repo = Repo.init_bare(target_dir)
            self.assertEquals(target_dir, new_repo._controldir)
        finally:
            shutil.rmtree(target_dir)
61
62
class RepositoryTests(unittest.TestCase):
    """Tests for reading and querying existing repositories.

    Tests open a copy of a fixture repository via open_repo() and store it in
    self._repo so that tearDown() removes it again.
    """

    def setUp(self):
        self._repo = None

    def tearDown(self):
        if self._repo is not None:
            tear_down_repo(self._repo)

    def _call_deprecated(self, func, *args, **kwargs):
        """Call func(*args, **kwargs) with DeprecationWarning ignored.

        The warning filters that were active before the call are restored
        afterwards. warnings.resetwarnings() would instead discard *every*
        filter, including ones installed by the interpreter or the test
        runner, and leak that change into later tests.

        :return: Whatever func returns.
        """
        saved_filters = warnings.filters[:]
        warnings.simplefilter("ignore", DeprecationWarning)
        try:
            return func(*args, **kwargs)
        finally:
            # Restore in place rather than rebinding warnings.filters, so the
            # list object itself is preserved.
            warnings.filters[:] = saved_filters

    def _init_bare_with_objects(self, source, path, shas):
        """Create a bare repo at path holding only the objects in shas.

        Each object is copied from source; HEAD is pointed at shas[0].

        :return: The new Repo.
        """
        repo = Repo.init_bare(path)
        for sha in shas:
            repo.object_store.add_object(source.get_object(sha))
        repo.refs['HEAD'] = shas[0]
        return repo

    def test_simple_props(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.controldir(), r.path)

    def test_ref(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.ref('refs/heads/master'),
                         'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_setitem(self):
        r = self._repo = open_repo('a.git')
        r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
        self.assertEquals('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
                          r["refs/tags/foo"].id)

    def test_get_refs(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual({
            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
            'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
            }, r.get_refs())

    def test_head(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')

    def test_get_object(self):
        r = self._repo = open_repo('a.git')
        obj = r.get_object(r.head())
        self.assertEqual(obj.type_name, 'commit')

    def test_get_object_non_existant(self):
        r = self._repo = open_repo('a.git')
        self.assertRaises(KeyError, r.get_object, missing_sha)

    def test_contains_object(self):
        r = self._repo = open_repo('a.git')
        self.assertTrue(r.head() in r)

    def test_contains_ref(self):
        r = self._repo = open_repo('a.git')
        self.assertTrue("HEAD" in r)

    def test_contains_missing(self):
        r = self._repo = open_repo('a.git')
        self.assertFalse("bar" in r)

    def test_commit(self):
        r = self._repo = open_repo('a.git')
        obj = self._call_deprecated(r.commit, r.head())
        self.assertEqual(obj.type_name, 'commit')

    def test_commit_not_commit(self):
        r = self._repo = open_repo('a.git')
        self._call_deprecated(self.assertRaises, errors.NotCommitError,
                              r.commit,
                              '4f2e6529203aa6d44b5af6e3292c837ceda003f9')

    def test_tree(self):
        r = self._repo = open_repo('a.git')
        commit = r[r.head()]
        tree = self._call_deprecated(r.tree, commit.tree)
        self.assertEqual(tree.type_name, 'tree')
        self.assertEqual(tree.sha().hexdigest(), commit.tree)

    def test_tree_not_tree(self):
        r = self._repo = open_repo('a.git')
        self._call_deprecated(self.assertRaises, errors.NotTreeError,
                              r.tree, r.head())

    def test_tag(self):
        r = self._repo = open_repo('a.git')
        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
        tag = self._call_deprecated(r.tag, tag_sha)
        self.assertEqual(tag.type_name, 'tag')
        self.assertEqual(tag.sha().hexdigest(), tag_sha)
        obj_class, obj_sha = tag.object
        self.assertEqual(obj_class, objects.Commit)
        self.assertEqual(obj_sha, r.head())

    def test_tag_not_tag(self):
        r = self._repo = open_repo('a.git')
        self._call_deprecated(self.assertRaises, errors.NotTagError,
                              r.tag, r.head())

    def test_get_peeled(self):
        # unpacked ref
        r = self._repo = open_repo('a.git')
        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
        self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
        self.assertEqual(r.get_peeled('refs/tags/mytag'), r.head())

        # packed ref with cached peeled value
        packed_tag_sha = 'b0931cadc54336e78a1d980420e3268903b57a50'
        parent_sha = r[r.head()].parents[0]
        self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
        self.assertEqual(r.get_peeled('refs/tags/mytag-packed'), parent_sha)

        # TODO: add more corner cases to test repo

    def test_get_peeled_not_tag(self):
        r = self._repo = open_repo('a.git')
        self.assertEqual(r.get_peeled('HEAD'), r.head())

    def test_get_blob(self):
        r = self._repo = open_repo('a.git')
        commit = r[r.head()]
        tree = r[commit.tree]
        blob_sha = tree.entries()[0][2]
        blob = self._call_deprecated(r.get_blob, blob_sha)
        self.assertEqual(blob.type_name, 'blob')
        self.assertEqual(blob.sha().hexdigest(), blob_sha)

    def test_get_blob_notblob(self):
        r = self._repo = open_repo('a.git')
        self._call_deprecated(self.assertRaises, errors.NotBlobError,
                              r.get_blob, r.head())

    def test_linear_history(self):
        r = self._repo = open_repo('a.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, [r.head(),
                                '2a72d929692c41d8554c07f6301757ba18a65d91'])

    def test_merge_history(self):
        r = self._repo = open_repo('simple_merge.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
                                'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
                                '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])

    def test_revision_history_missing_commit(self):
        r = self._repo = open_repo('simple_merge.git')
        self.assertRaises(errors.MissingCommitError, r.revision_history,
                          missing_sha)

    def test_out_of_order_merge(self):
        """Test that revision history is ordered by date, not parent order."""
        r = self._repo = open_repo('ooo_merge.git')
        history = r.revision_history(r.head())
        shas = [c.sha().hexdigest() for c in history]
        self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
                                'f507291b64138b875c28e03469025b1ea20bc614',
                                'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                'f9e39b120c68182a4ba35349f832d0e4e61f485c'])

    def test_get_tags_empty(self):
        r = self._repo = open_repo('ooo_merge.git')
        self.assertEqual({}, r.refs.as_dict('refs/tags'))

    def test_get_config(self):
        r = self._repo = open_repo('ooo_merge.git')
        self.assertEquals({}, r.get_config())

    def test_common_revisions(self):
        """
        This test demonstrates that ``find_common_revisions()`` actually returns
        common heads, not revisions; dulwich already uses
        ``find_common_revisions()`` in such a manner (see
        ``Repo.fetch_objects()``).
        """

        expected_shas = set(['60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])

        # Source for objects. Stored in self._repo so tearDown() removes it.
        r_base = self._repo = open_repo('simple_merge.git')

        # Re-create each-side of the merge in simple_merge.git.
        #
        # Since the trees and blobs are missing, the repository created is
        # corrupted, but we're only checking for commits for the purpose of this
        # test, so it's immaterial.
        r1_commits = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']

        r2_commits = ['4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']

        # Nested try/finally so that each temp dir is removed even if creating
        # the other one fails.
        r1_dir = tempfile.mkdtemp()
        try:
            r2_dir = tempfile.mkdtemp()
            try:
                r1 = self._init_bare_with_objects(r_base, r1_dir, r1_commits)
                r2 = self._init_bare_with_objects(r_base, r2_dir, r2_commits)

                # Finally, the 'real' testing!
                shas = r2.object_store.find_common_revisions(
                    r1.get_graph_walker())
                self.assertEqual(set(shas), expected_shas)

                shas = r1.object_store.find_common_revisions(
                    r2.get_graph_walker())
                self.assertEqual(set(shas), expected_shas)
            finally:
                shutil.rmtree(r2_dir)
        finally:
            shutil.rmtree(r1_dir)
308
309
class BuildRepoTests(unittest.TestCase):
    """Tests that build on-disk repos from scratch.

    Repos live in a temp dir and are torn down after each test. They start with
    a single commit in master having single file named 'a'.
    """

    def _write_file(self, name, contents):
        """Write contents to file name in self._repo's path, closing it."""
        f = open(os.path.join(self._repo.path, name), 'wb')
        try:
            f.write(contents)
        finally:
            f.close()

    def _commit(self, message, timestamp):
        """Call do_commit() on self._repo with fixed test identities.

        :param message: Commit message.
        :param timestamp: Used as both commit and author timestamp, with a
            timezone offset of 0.
        :return: Whatever do_commit() returns (the new commit's SHA).
        """
        return self._repo.do_commit(
            message,
            committer='Test Committer <test@nodomain.com>',
            author='Test Author <test@nodomain.com>',
            commit_timestamp=timestamp, commit_timezone=0,
            author_timestamp=timestamp, author_timezone=0)

    def setUp(self):
        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
        os.makedirs(repo_dir)
        r = self._repo = Repo.init(repo_dir)
        self.assertFalse(r.bare)
        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
        self.assertRaises(KeyError, lambda: r.refs['refs/heads/master'])

        self._write_file('a', 'file contents')
        r.stage(['a'])
        commit_sha = self._commit('msg', 12345)
        self.assertEqual([], r[commit_sha].parents)
        self._root_commit = commit_sha

    def tearDown(self):
        tear_down_repo(self._repo)

    def test_build_repo(self):
        r = self._repo
        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
        self.assertEqual(self._root_commit, r.refs['refs/heads/master'])
        expected_blob = objects.Blob.from_string('file contents')
        self.assertEqual(expected_blob.data, r[expected_blob.id].data)
        actual_commit = r[self._root_commit]
        self.assertEqual('msg', actual_commit.message)

    def test_commit_modified(self):
        r = self._repo
        self._write_file('a', 'new contents')
        r.stage(['a'])
        commit_sha = self._commit('modified a', 12395)
        self.assertEqual([self._root_commit], r[commit_sha].parents)
        _, blob_id = tree_lookup_path(r.get_object, r[commit_sha].tree, 'a')
        self.assertEqual('new contents', r[blob_id].data)

    def test_commit_deleted(self):
        r = self._repo
        os.remove(os.path.join(r.path, 'a'))
        r.stage(['a'])
        commit_sha = self._commit('deleted a', 12395)
        self.assertEqual([self._root_commit], r[commit_sha].parents)
        self.assertEqual([], list(r.open_index()))
        tree = r[r[commit_sha].tree]
        # Materialise the entries: comparing [] with an iterator would fail
        # even when the tree is empty.
        self.assertEqual([], list(tree.iteritems()))

    def test_commit_fail_ref(self):
        r = self._repo

        def set_if_equals(name, old_ref, new_ref):
            return False
        r.refs.set_if_equals = set_if_equals

        def add_if_new(name, new_ref):
            self.fail('Unexpected call to add_if_new')
        r.refs.add_if_new = add_if_new

        old_shas = set(r.object_store)
        self.assertRaises(errors.CommitError, self._commit, 'failed commit',
                          12345)
        new_shas = set(r.object_store) - old_shas
        self.assertEqual(1, len(new_shas))
        # Check that the new commit (now garbage) was added.
        new_commit = r[new_shas.pop()]
        self.assertEqual(r[self._root_commit].tree, new_commit.tree)
        self.assertEqual('failed commit', new_commit.message)

    def test_stage_deleted(self):
        r = self._repo
        os.remove(os.path.join(r.path, 'a'))
        r.stage(['a'])
        r.stage(['a'])  # double-stage a deleted path
411
412
class CheckRefFormatTests(unittest.TestCase):
    """Tests for the check_ref_format function.

    These are the same tests as in the git test suite.
    """

    def test_valid(self):
        valid_names = [
            'heads/foo',
            'foo/bar/baz',
            'refs///heads/foo',
            'foo./bar',
            'heads/foo@bar',
            'heads/fix.lock.error',
            ]
        for name in valid_names:
            self.assertTrue(check_ref_format(name), name)

    def test_invalid(self):
        invalid_names = [
            'foo',
            'heads/foo/',
            './foo',
            '.refs/foo',
            'heads/foo..bar',
            'heads/foo?bar',
            'heads/foo.lock',
            'heads/v@{ation',
            # NOTE(review): '\b' is a Python backspace escape, so this entry
            # exercises the control-character rule rather than the literal
            # backslash used in git's own test suite.
            'heads/foo\bar',
            ]
        for name in invalid_names:
            self.assertFalse(check_ref_format(name), name)
437
438
# Distinct, easily recognisable 40-character hex SHAs for the packed-refs tests.
ONES, TWOS, THREES, FOURS = (digit * 40 for digit in "1234")
443
class PackedRefsFileTests(unittest.TestCase):
    """Tests for parsing and writing packed-refs file contents."""

    def test_split_ref_line_errors(self):
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          'singlefield')
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          'badsha name')
        self.assertRaises(errors.PackedRefsException, _split_ref_line,
                          '%s bad/../refname' % ONES)

    def test_read_without_peeled(self):
        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
                         list(read_packed_refs(f)))

    def test_read_without_peeled_errors(self):
        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))

    def test_read_with_peeled(self):
        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
          ONES, TWOS, THREES, FOURS))
        self.assertEqual([
          (ONES, 'ref/1', None),
          (TWOS, 'ref/2', THREES),
          (FOURS, 'ref/4', None),
          ], list(read_packed_refs_with_peeled(f)))

    def test_read_with_peeled_errors(self):
        # These malformed inputs target the peeled-aware parser; the plain
        # read_packed_refs() rejects every '^' line regardless, so it would
        # not exercise these error paths.

        # A peeled line with no preceding ref line.
        f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
        self.assertRaises(errors.PackedRefsException, list,
                          read_packed_refs_with_peeled(f))

        # Two peeled lines in a row after a single ref line.
        f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
        self.assertRaises(errors.PackedRefsException, list,
                          read_packed_refs_with_peeled(f))

    def test_write_with_peeled(self):
        f = StringIO()
        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
                          {'ref/1': THREES})
        self.assertEqual(
          "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
          ONES, THREES, TWOS), f.getvalue())

    def test_write_without_peeled(self):
        f = StringIO()
        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
491
492
493 # Dict of refs that we expect all RefsContainerTests subclasses to define.
494 _TEST_REFS = {
495   'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
496   'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
497   'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
498   'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
499   'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
500   }
501
502
class RefsContainerTests(object):
    """Container-agnostic refs tests, mixed into concrete TestCase classes.

    Subclasses also derive from unittest.TestCase and set self._refs in
    setUp() to a refs container expected to hold _TEST_REFS.
    """

    def test_keys(self):
        all_keys = set(self._refs.keys())
        self.assertEqual(set(self._refs.allkeys()), all_keys)
        # A symref-loop fixture may exist; it is not part of _TEST_REFS.
        all_keys.discard('refs/heads/loop')
        self.assertEqual(set(_TEST_REFS.iterkeys()), all_keys)

        head_names = self._refs.keys('refs/heads')
        head_names.discard('loop')
        self.assertEqual(['master', 'packed'], sorted(head_names))
        tag_names = self._refs.keys('refs/tags')
        self.assertEqual(['refs-0.1', 'refs-0.2'], sorted(tag_names))

    def test_as_dict(self):
        # refs/heads/loop must not appear here even when it exists.
        self.assertEqual(_TEST_REFS, self._refs.as_dict())

    def test_setitem(self):
        sha = '42d06bd4b77fed026b154d16493e5deab78f02ec'
        self._refs['refs/some/ref'] = sha
        self.assertEqual(sha, self._refs['refs/some/ref'])

    def test_set_if_equals(self):
        current = '42d06bd4b77fed026b154d16493e5deab78f02ec'
        replacement = '9' * 40

        # A mismatching expected value leaves the ref untouched.
        self.assertFalse(
            self._refs.set_if_equals('HEAD', 'c0ffee', replacement))
        self.assertEqual(current, self._refs['HEAD'])

        self.assertTrue(
            self._refs.set_if_equals('HEAD', current, replacement))
        self.assertEqual(replacement, self._refs['HEAD'])

        self.assertTrue(
            self._refs.set_if_equals('refs/heads/master', None, replacement))
        self.assertEqual(replacement, self._refs['refs/heads/master'])

    def test_add_if_new(self):
        replacement = '9' * 40
        # An existing ref is not overwritten.
        self.assertFalse(
            self._refs.add_if_new('refs/heads/master', replacement))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['refs/heads/master'])

        self.assertTrue(self._refs.add_if_new('refs/some/ref', replacement))
        self.assertEqual(replacement, self._refs['refs/some/ref'])

    def test_set_symbolic_ref(self):
        name = 'refs/heads/symbolic'
        self._refs.set_symbolic_ref(name, 'refs/heads/master')
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref(name))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[name])

    def test_set_symbolic_ref_overwrite(self):
        name = 'refs/heads/symbolic'
        replacement = '9' * 40
        self.assertFalse(name in self._refs)
        self._refs[name] = replacement
        self.assertEqual(replacement, self._refs.read_loose_ref(name))
        self._refs.set_symbolic_ref(name, 'refs/heads/master')
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref(name))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[name])

    def test_check_refname(self):
        for good_name in ('HEAD', 'refs/heads/foo'):
            try:
                self._refs._check_refname(good_name)
            except KeyError:
                self.fail()

        for bad_name in ('refs', 'notrefs/foo'):
            self.assertRaises(KeyError, self._refs._check_refname, bad_name)

    def test_contains(self):
        self.assertTrue('refs/heads/master' in self._refs)
        self.assertFalse('refs/heads/bar' in self._refs)

    def test_delitem(self):
        name = 'refs/heads/master'
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs[name])
        del self._refs[name]
        self.assertRaises(KeyError, lambda: self._refs[name])

    def test_remove_if_equals(self):
        # A mismatching expected value leaves the ref in place.
        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['HEAD'])

        tag = 'refs/tags/refs-0.2'
        self.assertTrue(self._refs.remove_if_equals(
            tag, '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
        self.assertFalse(tag in self._refs)
599
600
class DictRefsContainerTests(RefsContainerTests, unittest.TestCase):
    """Runs the shared refs tests against an in-memory DictRefsContainer."""

    def setUp(self):
        # Hand the container a copy so mutations never touch _TEST_REFS.
        initial_refs = dict(_TEST_REFS)
        self._refs = DictRefsContainer(initial_refs)
605
606
class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
    """Runs the shared refs tests, plus on-disk specifics, against refs.git."""

    def setUp(self):
        self._repo = open_repo('refs.git')
        self._refs = self._repo.refs

    def tearDown(self):
        # self._repo is None if a test tore its repo down and failed before
        # installing a replacement; don't mask that failure with a second one.
        if self._repo is not None:
            tear_down_repo(self._repo)

    def _read_control_file(self, *path_components):
        """Return the raw contents of a file below self._refs.path.

        The file is closed even if reading it raises.
        """
        f = open(os.path.join(self._refs.path, *path_components), 'rb')
        try:
            return f.read()
        finally:
            f.close()

    def test_get_packed_refs(self):
        self.assertEqual({
          'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
          'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
          }, self._refs.get_packed_refs())

    def test_get_peeled_not_packed(self):
        # not packed
        self.assertEqual(None, self._refs.get_peeled('refs/tags/refs-0.2'))
        self.assertEqual('3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
                         self._refs['refs/tags/refs-0.2'])

        # packed, known not peelable
        self.assertEqual(self._refs['refs/heads/packed'],
                         self._refs.get_peeled('refs/heads/packed'))

        # packed, peeled
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs.get_peeled('refs/tags/refs-0.1'))

    def test_setitem(self):
        RefsContainerTests.test_setitem(self)
        contents = self._read_control_file('refs', 'some', 'ref')
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         contents[:40])

    def test_setitem_symbolic(self):
        ones = '1' * 40
        self._refs['HEAD'] = ones
        self.assertEqual(ones, self._refs['HEAD'])

        # ensure HEAD was not modified
        head_contents = self._read_control_file('HEAD')
        self.assertEqual('ref: refs/heads/master',
                         head_contents.split('\n', 1)[0])

        # ensure the symbolic link was written through
        master_contents = self._read_control_file('refs', 'heads', 'master')
        self.assertEqual(ones, master_contents[:40])

    def test_set_if_equals(self):
        RefsContainerTests.test_set_if_equals(self)

        # ensure symref was followed
        self.assertEqual('9' * 40, self._refs['refs/heads/master'])

        # ensure lockfile was deleted
        self.assertFalse(os.path.exists(
          os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
        self.assertFalse(os.path.exists(
          os.path.join(self._refs.path, 'HEAD.lock')))

    def test_add_if_new_packed(self):
        # don't overwrite packed ref
        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', '9' * 40))
        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                         self._refs['refs/tags/refs-0.1'])

    def test_add_if_new_symbolic(self):
        # Use an empty repo instead of the default.
        tear_down_repo(self._repo)
        # Forget the removed repo so tearDown() won't remove it a second time
        # if creating the replacement below fails.
        self._repo = None
        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
        os.makedirs(repo_dir)
        self._repo = Repo.init(repo_dir)
        refs = self._repo.refs

        nines = '9' * 40
        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
        self.assertFalse('refs/heads/master' in refs)
        self.assertTrue(refs.add_if_new('HEAD', nines))
        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
        self.assertEqual(nines, refs['HEAD'])
        self.assertEqual(nines, refs['refs/heads/master'])
        self.assertFalse(refs.add_if_new('HEAD', '1' * 40))
        self.assertEqual(nines, refs['HEAD'])
        self.assertEqual(nines, refs['refs/heads/master'])

    def test_follow(self):
        self.assertEquals(
          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
          self._refs._follow('HEAD'))
        self.assertEquals(
          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
          self._refs._follow('refs/heads/master'))
        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')

    def test_delitem(self):
        RefsContainerTests.test_delitem(self)
        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
        self.assertFalse(os.path.exists(ref_file))
        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())

    def test_delitem_symbolic(self):
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref('HEAD'))
        del self._refs['HEAD']
        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                         self._refs['refs/heads/master'])
        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))

    def test_remove_if_equals_symref(self):
        # HEAD is a symref, so shouldn't equal its dereferenced value
        self.assertFalse(self._refs.remove_if_equals(
          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
        self.assertTrue(self._refs.remove_if_equals(
          'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])

        # HEAD is now a broken symref
        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
        self.assertEqual('ref: refs/heads/master',
                         self._refs.read_loose_ref('HEAD'))

        self.assertFalse(os.path.exists(
            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
        self.assertFalse(os.path.exists(
            os.path.join(self._refs.path, 'HEAD.lock')))

    def test_remove_if_equals_packed(self):
        # test removing ref that is only packed
        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                         self._refs['refs/tags/refs-0.1'])
        self.assertTrue(
          self._refs.remove_if_equals('refs/tags/refs-0.1',
          'df6800012397fb85c56e7418dd4eb9405dee075c'))
        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])

    def test_read_ref(self):
        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
            self._refs.read_ref("refs/heads/packed"))
        self.assertEqual(None,
            self._refs.read_ref("nonexistant"))
752         self.assertEqual(None,
753             self._refs.read_ref("nonexistant"))